commit 76a9869339daeae42b797bbbe63783261b729cdd
Author: Graeme Connell
Date:   Thu Oct 20 16:52:03 2022 -0600

    Squashed history.

diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..af0e44f
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,5 @@
+.gopath
+.gocache
+.git
+enclave/build
+enclave/core.*
diff --git a/.github/workflows/dockercache/action.yml b/.github/workflows/dockercache/action.yml
new file mode 100644
index 0000000..b98a4dc
--- /dev/null
+++ b/.github/workflows/dockercache/action.yml
@@ -0,0 +1,41 @@
+name: Docker Caching
+description: Cache a docker image
+
+inputs:
+  dockerdir:
+    required: true
+    description: Build context directory passed to docker build
+  imagename:
+    required: true
+    description: Image name used for the -t tag, for --cache-from and for docker save
+  target:
+    required: false
+    description: Dockerfile stage passed to docker build via --target
+  dockerfile:
+    required: true
+    description: Dockerfile path passed via -f; its hash names the cached tarball and cache key
+
+runs:
+  using: composite
+
+  steps:
+    - name: Check for cached docker image
+      id: cached-docker
+      uses: actions/cache@v3
+      with:
+        path: dockerimage-${{ hashFiles(inputs.dockerfile) }}.tar
+        key: ${{ runner.os }}-dockerimagetar-${{ hashFiles(inputs.dockerfile) }}
+        restore-keys: ${{ runner.os }}-dockerimagetar-
+
+    - name: Load docker image
+      run: docker load --input dockerimage-*.tar || true  # best effort, a failed load must not fail the job
+      shell: bash
+
+    - name: Build/label docker image
+      run: docker build -t ${{ inputs.imagename }} -f ${{ inputs.dockerfile }} ${{ inputs.dockerdir }} --target=${{ inputs.target }} --cache-from ${{ inputs.imagename }}:latest
+      shell: bash
+
+    - name: Save docker image
+      if: steps.cached-docker.outputs.cache-hit != 'true'  # only write a tarball when the exact key was not restored
+      run: docker save --output dockerimage-${{ hashFiles(inputs.dockerfile) }}.tar ${{ inputs.imagename }}:latest $(docker history -q ${{ inputs.imagename }}:latest | grep -v missing)
+      shell: bash
diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
new file mode 100644
index 0000000..a190f98
--- /dev/null
+++ b/.github/workflows/push.yml
@@ -0,0 +1,42 @@
+name: Build and push Docker image
+
+on:
+  release:
+    types: [published]
+
+jobs:
+  build-and-push:
+    runs-on: ubuntu-latest
+    permissions:
+      id-token: write  # This is required for requesting the JWT
+      contents: read  # This is required for actions/checkout
+    steps:
+      - name: Checkout main project
+        uses: actions/checkout@v3
+        with:
+          submodules: true
+
+      - name: Docker cache
+        uses: ./.github/workflows/dockercache
+        with:
+          dockerdir: .
+          imagename: svr2_buildenv
+          target: builder
+          dockerfile: docker/Dockerfile
+
+      - name: 'Az CLI login'
+        uses: azure/login@v1
+        with:
+          client-id: ${{ secrets.AZURE_CLIENT_ID }}
+          tenant-id: ${{ secrets.AZURE_TENANT_ID }}
+          subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
+
+      - name: 'Docker login'
+        run:
+          az acr login --name ${{ secrets.AZURE_CONTAINER_REGISTRY_NAME }}
+
+      - name: Build and push container image
+        run: |
+          make container
+          docker tag svr2_runenv:latest "${{ secrets.REGISTRY_LOGIN_SERVER }}/svr2:${GITHUB_REF_NAME}"
+          docker push "${{ secrets.REGISTRY_LOGIN_SERVER }}/svr2:${GITHUB_REF_NAME}"
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..3518323
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,29 @@
+name: CI
+on: [push]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    permissions:
+      packages: read
+      contents: read
+
+    steps:
+      - name: Checkout main project
+        uses: actions/checkout@v3
+        with:
+          submodules: true
+
+      - name: Docker cache
+        uses: ./.github/workflows/dockercache
+        with:
+          dockerdir: .
+          imagename: svr2_buildenv
+          target: builder
+          dockerfile: docker/Dockerfile
+
+      - name: Build and test
+        run: make
+
+      - name: Validate
+        run: make docker_validate
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6836034
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+*.sw?
+.gocache +.gopath +**/.devcontainer +**/.vscode +**/*.code-workspace +.tool-versions +.idea diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..855d0d0 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,18 @@ +[submodule "enclave/protobuf"] + path = enclave/protobuf + url = https://github.com/protocolbuffers/protobuf.git +[submodule "enclave/noise-c"] + path = enclave/noise-c + url = https://github.com/rweather/noise-c.git +[submodule "enclave/SipHash"] + path = enclave/SipHash + url = https://github.com/veorq/SipHash +[submodule "enclave/googletest"] + path = enclave/googletest + url = https://github.com/google/googletest +[submodule "enclave/libsodium"] + path = enclave/libsodium + url = https://github.com/jedisct1/libsodium +[submodule "docker/aws-nitro-enclaves-nsm-api"] + path = docker/aws-nitro-enclaves-nsm-api + url = https://github.com/aws/aws-nitro-enclaves-nsm-api.git diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..be3f7b2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. 
Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. 
+ + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. 
+ + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. 
+ + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. 
+ + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. 
+ + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. 
If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. 
If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. 
+ + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. 
For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. 
+ + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. 
+ + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. 
There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +<https://www.gnu.org/licenses/>. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..67cb42a --- /dev/null +++ b/Makefile @@ -0,0 +1,68 @@ +dockall: docker_all + +all: validate host enclave control + +MAKE_ARGS ?= --keep-going + +enclave_testbin: | git + $(MAKE) $(MAKE_ARGS) -C enclave build/enclave.test + +validate: + $(MAKE) $(MAKE_ARGS) -C enclave validate + $(MAKE) $(MAKE_ARGS) -C host validate + ./check_copyrights.sh + +git: + git submodule init || true + git submodule update || true + +enclave: enclave_testbin + $(MAKE) $(MAKE_ARGS) -C enclave all + +enclave_test: + $(MAKE) $(MAKE_ARGS) -C enclave test + +host: enclave_testbin + $(MAKE) $(MAKE_ARGS) -C host all + +control: + $(MAKE) $(MAKE_ARGS) -C host control + +clean: + $(MAKE) $(MAKE_ARGS) -C enclave clean + $(MAKE) $(MAKE_ARGS) -C host clean + +dockerbase: | git + docker build -f docker/Dockerfile -t svr2_buildenv --target=builder . 
+ +PARALLEL ?= $(shell cat /proc/cpuinfo | grep '^cpu cores' | awk '{ sum += $$4 } END { print sum }') +DOCKER_MAKE_ARGS ?= -j$(PARALLEL) MAKE_ARGS="$(MAKE_ARGS)" +ARCH ?= $(shell arch) +ifeq ($(ARCH),arm64) + DOCKER_MAKE_ARGS += 'GO_TEST_FLAGS=-short' # long tests can cause qemu crashes in x86 emulation +endif +DOCKER_ARGS ?= +docker_%: dockerbase + docker run \ + -v "$$(pwd):/src" \ + -u "$$(id -u):$$(id -g)" \ + $(DOCKER_ARGS) \ + svr2_buildenv /bin/bash -c "make V=$(V) $(DOCKER_MAKE_ARGS) $*" + +dockersh: dockerbase + docker run --rm -it \ + -v "$$(pwd):/src" \ + -u "$$(id -u):$$(id -g)" \ + -e "TERM=xterm-256color" \ + $(DOCKER_ARGS) \ + svr2_buildenv + +container: dockerbase + docker build -f docker/Dockerfile -t svr2_runenv . + +enclave_release: docker_enclave_releaser +enclave_releaser: enclave host # depends on 'host' so its tests will run + cp -vn enclave/build/enclave.signed "enclave/releases/default.$$(/opt/openenclave/bin/oesign dump -e enclave/build/enclave.signed | fgrep -i mrenclave | cut -d '=' -f2)" + cp -vn enclave/build/enclave.small "enclave/releases/small.$$(/opt/openenclave/bin/oesign dump -e enclave/build/enclave.small | fgrep -i mrenclave | cut -d '=' -f2)" + +.PHONY: all clean enclave host dockersh docker dockerbase git validate enclave_testbin control enclave_release enclave_releaser diff --git a/README.md b/README.md new file mode 100644 index 0000000..e1b58c8 --- /dev/null +++ b/README.md @@ -0,0 +1,78 @@ +# Secure Value Recovery Service v2 + +The SecureValueRecovery2 (SVR2) project aims to store client-side secrets +server-side protected by a human-remembered (and thus, low-entropy) pin. +It does so by limiting the number of attempts to recover such a secret to +a very small guess count, to disallow brute-force attacks that would otherwise +trivially recover such a secret. 
To limit the number of recovery attempts, +SVR2 keeps persistent state on the guess count, along with the secret itself, +in a multi-replica, strong-consensus, shared storage mechanism based on +in-memory Raft. + +SVR2 is designed, first and foremost, to not leak the secret +material, and, secondarily, to provide the material back to clients. Given +this, if there is a choice between "lose the secret material forever" and +"store the secret material but potentially leak it", we'll choose the former. +This means that, in some cases, we've chosen to allow the system to lose +_liveness_ (the ability to serve back anything) in order to maintain the +security properties of the system. We'll happily discard every secret in the +system rather than expose one of the secrets to a leak. + +## History + +SVR2 is a successor to the +[SecureValueRecovery](https://github.com/signalapp/SecureValueRecovery) +project that Signal already uses for the above stated purpose. We've built +a second version of this system to handle a few specific issues: + +- Update to SGX DCAP capabilities +- Provide better operational handling of crashes/failures via self-healing +- Simplify to a single-replica-group model since SGX CPUs now have an EPC size of hundreds of gigabytes + +As part of SGX DCAP updates, this project also attempts to be as safe as +possible while running on SGX TME memory, compared to the differing +security guarantees of the SGX MEE memory utilized in the original version. + +## Building + +In order to build and test everything in this repository, you should be able to +just run `make` at the top-level. You must have a valid `docker` installed +locally to do this. 
Running this at the top-level will: + +- Create a docker image in which to build things +- Build `enclave/enclave.test` (a debug enclave for simulation/testing) and + `enclave/enclave.signed` (a production enclave) +- Build and test the host-side process in `host/` + +If you'd like to incrementally build and change things, you can do so by +running `make dockersh`. This will build the aforementioned docker image, +then drop you inside of it in a `bash` shell. You can then run any of + +``` +make all # Make everything +make enclave # Make all of the enclave stuff +make host # Make all of the host stuff +(cd enclave && make $SOMETARGET) # Make just a specific target in enclave +(cd host && make $SOMETARGET) # Make just a specific target in host +``` + +## Code layout + +Code is divided into a few main directories at the top-level + +* `docker` - Contains the spec for the docker image used to build everything else. +* `shared` - Contains all code/configs that must be shared between the host and enclave. + This includes any protos that the host and enclave use to communicate, + and the definitions of ocalls/ecalls (the `*.edl` files). +* `enclave` - Contains all code and build rules for building the in-enclave binary. + This is a C++ codebase. +* `host` - Contains all code and build rules for building the host-side binary, which + starts up an enclave, then communicates with it. This is a Go codebase. +* `docs` - Contains additional documentation above and beyond the host/enclave `README.md` + docs on specific topics. + +## License + +Copyright 2023 Signal Messenger, LLC + +Licensed under the [AGPLv3](LICENSE) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..2c23595 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,11 @@ +## Reporting a Vulnerability + +If you've found a security vulnerability in this repository, +please report it via email to <security@signal.org>. 
+ +Please only use this address to report security flaws in the Signal application (including this +repository). For questions, support, or feature requests concerning the app, please submit a +[support request][] or join the [unofficial community forum][]. + +[support request]: https://support.signal.org/hc/requests/new +[unofficial community forum]: https://community.signalusers.org/ diff --git a/check_copyrights.sh b/check_copyrights.sh new file mode 100755 index 0000000..0ca71bf --- /dev/null +++ b/check_copyrights.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +OUT=0 +for pattern in '*.c' '*.cc' '*.h' '*.go' '*.proto'; do + for file in `find ./ -name $pattern -type f | grep -v -f <(cat .gitmodules | grep path | awk '{print $3}') | egrep -v 'gopath|enclave/build|host/enclave/c'`; do + if ! grep -q Copyright $file; then + OUT=1 + echo "Missing copyright in '$file'" 1>&2 + fi + done +done +exit $OUT diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..4a59652 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,139 @@ +# syntax=docker/dockerfile:1 +# To build use: +# docker build -t oebuild . +FROM amd64/debian:bullseye-20220912 AS base + +LABEL description="linux build environment for sgx." 
+ +COPY docker/apt.conf docker/sources.list /etc/apt/ +RUN apt-get update && \ + apt-get -y install \ + gpg \ + gnupg2 \ + wget \ + software-properties-common + +COPY docker/sgx.sources.list docker/ms.sources.list /etc/apt/sources.list.d/ +# ms and intel repos keep old packages around, +# however if they remove some of these in the future +# binary packages can be retrieved from github releases +RUN wget -qO - https://download.01.org/intel-sgx/sgx_repo/ubuntu/intel-sgx-deb.key | apt-key add - && \ + wget -qO - https://packages.microsoft.com/keys/microsoft.asc | apt-key add - && \ + apt-get update && \ + apt -y install \ + libsgx-ae-id-enclave=1.16.100.2-focal1 \ + libsgx-ae-pce=2.19.100.3-focal1 \ + libsgx-ae-qe3=1.16.100.2-focal1 \ + libsgx-dcap-ql=1.16.100.2-focal1 \ + libsgx-dcap-ql-dev=1.16.100.2-focal1 \ + libsgx-enclave-common=2.19.100.3-focal1 \ + libsgx-headers=2.19.100.3-focal1 \ + libsgx-pce-logic=1.16.100.2-focal1 \ + libsgx-qe3-logic=1.16.100.2-focal1 \ + libsgx-urts=2.19.100.3-focal1 \ + open-enclave=0.19.0 + +FROM public.ecr.aws/amazonlinux/amazonlinux@sha256:94e7183b0739140dbd5b639fb7600f0a2299cec5df8780c26d9cb409da5315a9 AS nsmbuild +ENV HOST_MACHINE=x86_64 +ENV RUST_VERSION=1.58.1 +ENV RUSTUP_HOME=/usr/local/rustup \ + CARGO_HOME=/usr/local/cargo \ + PATH=/usr/local/cargo/bin:$PATH + +RUN yum install -y gcc + +RUN set -eux; \ + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs/ | sh -s -- --default-toolchain ${RUST_VERSION} -y ; \ + chmod -R a+w $RUSTUP_HOME $CARGO_HOME; \ + rustup --version; \ + cargo --version; \ + rustc --version + +COPY docker/aws-nitro-enclaves-nsm-api /build +COPY docker/aws-nitro.Cargo.lock /build/Cargo.lock +WORKDIR /build + +RUN set -eux; \ + (cd nsm-lib && cargo build --release --locked) +RUN ar mD target/release/libnsm.a $(ar t target/release/libnsm.a | env -u LANG LC_ALL=C sort) +COPY docker/check_hash.sh docker/sha256.* ./ +RUN ./check_hash.sh target/release/libnsm.a + +FROM base AS builder + +RUN mkdir /src && 
\ + apt-get update && \ + apt-get -y install \ + clang-11 \ + libssl-dev \ + gdb \ + libtool \ + bison \ + automake \ + flex \ + libcurl4 \ + pkg-config \ + make \ + unzip \ + git \ + gcc \ + libgtest-dev + +COPY docker/check_hash.sh docker/sha256.* ./ + +ARG PROTOBUF_PLATFORM=linux-x86_64 +ARG PROTOBUF_VERSION=21.8 +ARG PROTOBUF_BASE=protoc-${PROTOBUF_VERSION}-${PROTOBUF_PLATFORM} + +RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/${PROTOBUF_BASE}.zip \ + && /bin/bash ./check_hash.sh ${PROTOBUF_BASE}.zip \ + && mkdir -p ${PROTOBUF_BASE} \ + && cd ${PROTOBUF_BASE} \ + && unzip -o ../${PROTOBUF_BASE}.zip \ + && cd .. \ + && mv ${PROTOBUF_BASE} /opt/protobuf + +ARG GOLANG_PLATFORM=linux-amd64 +ARG GOLANG_VERSION=1.20.2 +ARG GOLANG_TAR_GZ=go${GOLANG_VERSION}.${GOLANG_PLATFORM}.tar.gz + +RUN wget https://go.dev/dl/${GOLANG_TAR_GZ} \ + && /bin/bash ./check_hash.sh ${GOLANG_TAR_GZ} \ + && tar xzf ${GOLANG_TAR_GZ} \ + && mv go /opt/ + +ENV PATH="/opt/openenclave/bin:/opt/go/bin:/opt/protobuf/bin:${PATH}" +ENV GOROOT="/opt/go" +ENV GOBIN="/opt/go/bin" +ENV PKG_CONFIG_PATH="/opt/openenclave/share/pkgconfig" + +ARG PROTOC_GEN_GO_GITREV=6875c3d7242d1a3db910ce8a504f124cb840c23a +RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@${PROTOC_GEN_GO_GITREV} +RUN echo "export PS1='buildenv: \w$ '" >> /etc/bash.bashrc + +# Set this after `go install` so we don't use the same cache as root. +ENV GOPATH="/src/.gopath" +ENV GOCACHE="/src/.gocache" + +WORKDIR /src +COPY --from=nsmbuild /build/target/release/libnsm.a /opt/nsm/libnsm.a +COPY --from=nsmbuild /build/target/release/nsm.h /opt/nsm/nsm.h + +CMD ["/bin/bash"] + +FROM builder AS build + +COPY . 
/src +RUN cd /src && make clean && make -j16 all enclave_releaser + +FROM base AS runner + +RUN apt-get update && apt-get install -y \ + libsgx-dcap-default-qpl=1.16.100.2-focal1 \ + libsgx-dcap-default-qpl-dev=1.16.100.2-focal1 +COPY docker/sgx_default_qcnl_azure.conf /etc/sgx_default_qcnl.conf +COPY --from=build /src/host/main /bin/svr2 +COPY --from=build /src/enclave/releases /enclaves +COPY --from=build /src/host/cmd/control/control /bin/svr2control + +ENTRYPOINT ["/bin/svr2"] diff --git a/docker/apt.conf b/docker/apt.conf new file mode 100644 index 0000000..8dc2c18 --- /dev/null +++ b/docker/apt.conf @@ -0,0 +1,15 @@ +Apt { + Architecture "amd64"; + Architectures "amd64"; +}; + +Acquire::Check-Valid-Until "false"; +Acquire::Languages "none"; +Binary::apt-get::Acquire::AllowInsecureRepositories "false"; + +APT::Install-Recommends "false"; + +// go easy on snapshot.debian.org +Acquire::http::Dl-Limit "10000"; +Acquire::https::Dl-Limit "10000"; +Acquire::Retries "5"; diff --git a/docker/aws-nitro-enclaves-nsm-api b/docker/aws-nitro-enclaves-nsm-api new file mode 160000 index 0000000..944562d --- /dev/null +++ b/docker/aws-nitro-enclaves-nsm-api @@ -0,0 +1 @@ +Subproject commit 944562dacce23dc947bea1df60b5dd3a51fb8c4f diff --git a/docker/aws-nitro.Cargo.lock b/docker/aws-nitro.Cargo.lock new file mode 100644 index 0000000..183c524 --- /dev/null +++ b/docker/aws-nitro.Cargo.lock @@ -0,0 +1,561 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "aws-nitro-enclaves-nsm-api" +version = "0.2.1" +dependencies = [ + "libc", + "log", + "nix 0.20.2", + "serde", + "serde_bytes", + "serde_cbor", +] + +[[package]] +name = "bitflags" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" + +[[package]] +name = "cbindgen" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6358dedf60f4d9b8db43ad187391afe959746101346fe51bb978126bec61dfb" +dependencies = [ + "heck", + "indexmap", + "log", + "proc-macro2", + "quote", + "serde", + "serde_json", + "syn 1.0.109", + "tempfile", + "toml", +] + +[[package]] +name = "cc" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "errno" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" +dependencies = [ + "errno-dragonfly", + "libc", + "windows-sys 0.45.0", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" +dependencies = [ + "libc", +] + +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "io-lifetimes" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" +dependencies = [ + "hermit-abi 0.3.1", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + +[[package]] +name = "libc" +version = "0.2.141" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" + +[[package]] +name = "linux-raw-sys" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" + +[[package]] +name = "log" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "memoffset" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +dependencies = [ + "autocfg", +] + +[[package]] +name = "nix" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5e06129fb611568ef4e868c14b326274959aa70ff7776e9d55323531c374945" +dependencies = [ + "bitflags", + "cc", + "cfg-if", + "libc", + "memoffset", +] + +[[package]] +name = "nix" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "memoffset", +] + +[[package]] +name = "nsm-lib" +version = "0.2.1" +dependencies = [ + "aws-nitro-enclaves-nsm-api", + "cbindgen", + "serde_bytes", +] + +[[package]] +name = "nsm-test" +version = "0.2.1" +dependencies = [ + "aws-nitro-enclaves-nsm-api", + "nix 0.20.2", + "nsm-lib", + "serde_bytes", + "serde_cbor", + "signal-hook", + "threadpool", + "vsock", +] + +[[package]] +name = "num_cpus" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +dependencies = [ + "hermit-abi 0.2.6", + "libc", +] + +[[package]] +name = "proc-macro2" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rustix" +version = "0.37.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aef160324be24d31a62147fae491c14d2204a3865c7ca8c3b0d7f7bcb3ea635" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.48.0", +] + +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + +[[package]] +name = "serde" +version = "1.0.159" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_bytes" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "416bda436f9aab92e02c8e10d49a15ddd339cea90b6e340fe51ed97abb548294" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half", + "serde", +] + +[[package]] +name = "serde_derive" +version = "1.0.159" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.13", +] + +[[package]] +name = "serde_json" +version = "1.0.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "signal-hook" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall", + "rustix", + "windows-sys 0.45.0", +] + +[[package]] +name = "threadpool" +version 
= "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" +dependencies = [ + "num_cpus", +] + +[[package]] +name = "toml" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + +[[package]] +name = "unicode-ident" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" + +[[package]] +name = "vsock" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c8e1df0bf1e1b28095c24564d1b90acae64ca69b097ed73896e342fa6649c57" +dependencies = [ + "libc", + "nix 0.24.3", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + 
+[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" diff --git a/docker/check_hash.sh b/docker/check_hash.sh new file mode 100755 index 0000000..a971bdc --- /dev/null +++ b/docker/check_hash.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -e +EXPECTED_HASH="$(cat sha256."$(basename "$1")")" +ACTUAL_HASH="$(sha256sum "$1")" +echo "Checking hash for '$1'" +echo "Expected: '$EXPECTED_HASH'" +echo "Actual: '$ACTUAL_HASH'" +exec [ "$EXPECTED_HASH" == "$ACTUAL_HASH" ] diff --git a/docker/ms.sources.list b/docker/ms.sources.list new file mode 100644 index 0000000..920de69 --- /dev/null +++ b/docker/ms.sources.list @@ -0,0 +1 @@ +deb [arch=amd64] https://packages.microsoft.com/ubuntu/20.04/prod focal main diff --git a/docker/sgx.sources.list b/docker/sgx.sources.list new file mode 100644 index 
0000000..89649f4 --- /dev/null +++ b/docker/sgx.sources.list @@ -0,0 +1 @@ +deb [arch=amd64] https://download.01.org/intel-sgx/sgx_repo/ubuntu focal main diff --git a/docker/sgx_default_qcnl_azure.conf b/docker/sgx_default_qcnl_azure.conf new file mode 100644 index 0000000..a91c6eb --- /dev/null +++ b/docker/sgx_default_qcnl_azure.conf @@ -0,0 +1,29 @@ +{ + "pccs_url": "https://global.acccache.azure.net/sgx/certification/v3/", + + "use_secure_cert": false, + + "collateral_service": "https://pccs/sgx/certification/v3/", + + "pccs_api_version": "3.1", + + "retry_times": 6, + + "retry_delay": 5, + + "local_pck_url": "http://169.254.169.254/metadata/THIM/sgx/certification/v3/", + + "pck_cache_expire_hours": 48, + + "custom_request_options" : { + "get_cert" : { + "headers": { + "metadata": "true" + }, + "params": { + "api-version": "2021-07-22-preview" + } + } + } +} + diff --git a/docker/sha256.go1.20.2.linux-amd64.tar.gz b/docker/sha256.go1.20.2.linux-amd64.tar.gz new file mode 100644 index 0000000..9693c1b --- /dev/null +++ b/docker/sha256.go1.20.2.linux-amd64.tar.gz @@ -0,0 +1 @@ +4eaea32f59cde4dc635fbc42161031d13e1c780b87097f4b4234cfce671f1768 go1.20.2.linux-amd64.tar.gz diff --git a/docker/sha256.libnsm.a b/docker/sha256.libnsm.a new file mode 100644 index 0000000..4f17459 --- /dev/null +++ b/docker/sha256.libnsm.a @@ -0,0 +1 @@ +350d5fde8e139301aaf39a47509aa0fa0c9ced472d4ae30c45c5504b1ef45490 target/release/libnsm.a diff --git a/docker/sha256.protoc-21.8-linux-x86_64.zip b/docker/sha256.protoc-21.8-linux-x86_64.zip new file mode 100644 index 0000000..22c3202 --- /dev/null +++ b/docker/sha256.protoc-21.8-linux-x86_64.zip @@ -0,0 +1 @@ +f90d0dd59065fef94374745627336d622702b67f0319f96cee894d41a974d47a protoc-21.8-linux-x86_64.zip diff --git a/docker/sources.list b/docker/sources.list new file mode 100644 index 0000000..337ccde --- /dev/null +++ b/docker/sources.list @@ -0,0 +1,5 @@ +deb http://snapshot.debian.org/archive/debian/20220912T000000Z/ bullseye main +deb 
http://snapshot.debian.org/archive/debian/20220912T000000Z/ bullseye-updates main + +deb http://snapshot.debian.org/archive/debian/20220912T000000Z/ buster main +deb http://snapshot.debian.org/archive/debian/20220912T000000Z/ buster-updates main diff --git a/docs/Healing.md b/docs/Healing.md new file mode 100644 index 0000000..73ffb7a --- /dev/null +++ b/docs/Healing.md @@ -0,0 +1,98 @@ +# Healing + +When we talk about "healing" in SVR2, we're currently talking about membership +change in the Raft replica group. In SVR2, we break healing down into +the following sub-problems: + +- Remove old nodes when they're unable to serve (rebooted, etc) +- Add new nodes to replace removed nodes + +Removing of nodes is currently unimplemented, more on that later. + +## Adding new nodes + +A new SVR2 node that wants to become a replica within the Raft cluster +currently goes through the following state transitions to get it to a +serving state. These are currently driven by host-side requests, but +in near-future we hope to make the decision to promote replicas to +voting status an in-enclave decision. + +In short, a node starts up without any Raft state. It then decides +to follow one of two paths: + +- Start a new Raft group as the sole replica/leader. +- Join an existing Raft group by talking to some replica in that group. + +Starting a new Raft group is out of scope of this doc: it just does :). +Joining an existing group, though, is the primary mechanism by which new +nodes are added. We assume that we're running in an environment where +broken nodes are replaced (by shutting down the old node and starting up +a new one) as K8S and most other cloud provider workflows allow. In this +case, "adding a new node" is actually "starting a new node, and having +it request to join the group". + +Breaking this down in more detail, a node that wants to join a group +goes through a set of state transitions by talking to other nodes: + +1. 
Host tells the enclave about a single peer ID +1. Get information about the Raft group (group ID, other replicas, etc) +1. Replicate existing state (logs/database) up to a recent commit +1. Send a `request_membership` request to the leader +1. Send a `request_vote` request to the leader + +These steps are accomplished by calling enclave-to-enclave (e2e) +transactions (protos in `enclave/proto/e2e.proto`). + +### Host join request + +The host starts the join by sending a `HostToEnclaveRequest.join_raft` call +to the enclave, with a PeerID it knows about that's part of the existing group. + +### Get information about Raft group + +The enclave calls the `e2e::TransactionRequest.get_raft` transaction on the one +peer ID it knows about (the one passed in by `join_raft`). This gives it the +`RaftGroupConfig` (immutable Raft configuration) and `raft.ReplicaGroup` +(current membership in the group). It then transitions to the next state. + +### Replicate existing state + +The enclave picks a random peer from among those in the `ReplicaGroup` +(it will eventually make a more interesting decision about which peer to talk +to), then makes a series of `e2e::TransactionRequest.replicate_state` +requests against that peer. These requests first pull in all logs from +the remote peer until the new node reaches the responder's commit index. +At this point, the new node will start to request and receive a combination +of any new logs committed since that first commit point and database state. +When it's read in the full keyspace of the database (applying as it goes +any newly-committed logs it receives), it will be at a point where it has +all logs and all database state up to the latest committed index of the +responder. It then transitions to the next state. + +### Request join + +The enclave then requests to join the group as a non-voting member.
+It sends an `e2e::TransactionRequest.request_raft_membership` to the +leader of the group (it actually sends it to all members, but should be +changed in the near future to target just the suspected leader). +If this request succeeds, it is now in a ReplicaGroup config on a +non-committed leader log. The leader will begin to treat it as a normal +non-voting member initially, including replicating to it via AppendEntries +any uncommitted logs and telling it when those logs commit. The node +stays in this state, watching its raft log, until it sees that a +ReplicaGroup log containing its PeerID has been committed. At this point, +it knows that it is now a member, and transitions its local state to +act as such. + +### Request vote + +This is another mechanism that's currently driven by the host, but should +probably become an automatic enclave function. After an enclave becomes +a non-voting member of the Raft group, the host can send a +`HostToEnclaveRequest.request_voting` request to the enclave. This +instructs the enclave to send an `e2e::TransactionRequest.request_raft_voting` +call to its current leader. On success, the leader switches the replica's +voting status from non-voting to voting by writing a new ReplicaGroup with +the associated changes to its log. The requesting node (and all other +nodes in the Raft group) hear about this change via normal mechanisms for +ReplicaGroup change. diff --git a/docs/Messages.md b/docs/Messages.md new file mode 100644 index 0000000..ade0e1d --- /dev/null +++ b/docs/Messages.md @@ -0,0 +1,176 @@ +# Enclave Messages: The Enclave's Logical Interface +The SVR2 enclave interface defined in [svr2.edl](../../shared/svr2.edl) is +generic. It provides initialization and message passing functions that are +independent of the application logic. The _logical_ interface of the enclave +is defined by these messages and how the enclave responds to them.
In what +follows we will think of the different messages that can be sent to the enclave +as RPCs and refer to them as "calls" or "commands". + +## Three Interfaces: Host, Peer, and Client +SVR2 enclaves interact with three different types of entities: the _host_ that +makes ECALLs and receives OCALLs from the enclave, other _peer_ enclaves, and +_clients_ that are using the service to store and recover secure values. + +The host interface includes a number of administrative commands (create or join +a replica group, get enclave status, tick the Raft timer, etc.). It also has +commands to forward wrapped peer or client requests. + +The peer interface includes Raft protocol messages, attestation updates, and +a number of other "Enclave to Enclave [E2E] transactions" used to get +information about a replica group, transfer database state, and join a replica +group. + +The client interface is the raison d'être for SVR2. It allows clients to backup, +restore, or delete a secure value. Everything else in this system is here to +ensure that this is done securely and reliably. + +We will use this abstraction to organize this document, but it does *not* align +perfectly with the organization of the code. The code organization reflects +important implementation details as follows: + +* All messages to the enclave are sent in an `UntrustedMessage` + ([shared/proto/msgs.proto](../../shared/proto/msgs.proto)). These + may be direct commands or forwarded messages from peers or clients. +* Host calls that will not trigger response messages are sent as a simple + `UntrustedMessage`. + We will call these _synchronous host calls_. +* Host calls that MAY trigger response messages are sent as a + `HostToEnclaveRequest` inside an `UntrustedMessage`. It is important to note + that *all* client requests are sent this way. +* `HostToEnclaveRequest`s are further subdivided into administrative requests + and requests on behalf of clients.
Client requests may be Noise encrypted + (backup, restore, delete) or unencrypted (create new client, create backup). + Encrypted client messages are defined in + ([shared/proto/msgs.proto](../../shared/proto/msgs.proto)). Unencrypted ones + are defined as submessages of `HostToEnclaveRequest` in + ([shared/proto/msgs.proto](../../shared/proto/msgs.proto)). +* Peer calls are all sent as `PeerMessage` messages inside an + `UntrustedMessage`. These messages contain raw bytes that either hold handshake + information or a Noise encrypted `EnclaveToEnclaveMessage`. These messages + are defined in [enclave/proto/e2e.proto](../proto/e2e.proto) + +There is another important property that we will note on all of the calls we +describe: some calls require that a new Raft log entry be accepted and committed +by this node's replica group in order to complete, others do not. We will say that +the calls that succeed or fail based on whether a log entry was successfully +committed "require Raft consensus". + +## The Host Interface + +### Synchronous Calls +There are two messages the host can send to the enclave that act as +synchronous calls - once the ECALL returns the action is complete. No messages +will be sent from the enclave in response to these calls. None of these require Raft +consensus. They are: + +1. `TimerTick` passes a unix timestamp that causes the enclave to update its + internal time (which is used to obtain a consensus `group_time` with its peers), + then perform a `RaftStep`. +1. `ResetPeer` lets this Raft instance know that the given peer ID + may have lost some of the messages we sent to it previously. + +### Asynchronous Calls +All other calls from the host may cause the enclave to send response messages +that must be handled asynchronously. These are all sent as a +`HostToEnclaveRequest` inside an `UntrustedMessage`. These calls include: + +1. **Reconfigure** (`enclaveconfig.EnclaveConfig`) Reconfigure the replica with + new host-supplied configuration.
+1. **GetEnclaveStatus** (`bool`) Retrieves basic + information about the status of a replica. Has more detail if the + replica is a leader. +1. **DeleteBackup** (`DeleteBackupRequest` - _requires consensus_) Used by host + to delete a backup, e.g., when the account is deleted. +1. **CreateNewRaftGroup** (`RaftConfig`) Request that we create a new raft group + from scratch, setting ourselves as the sole member and leader. This should be + done to seed a new Raft, after which we should request `JoinRaft` instead. +1. **JoinRaft** (`JoinRaftRequest` - _requires consensus_) This tells the + enclave to join a particular Replica group. This call requires that the + target raft group be up and running. Raft joining is a + multi-step process described in detail in [Healing.md](./Healing.md). In + this process there will be an enclave-to-enclave call that creates a new + Raft configuration. This change requires consensus of the existing + voting members. If successful the enclave will be a non-voting, + up-to-date member of the specified Raft. +1. **PingPeer** (`EnclavePeer`) Tells an enclave to check connectivity with + another peer. +1. **RequestVoting** (`bool` - _requires consensus_) Tells an enclave that + is already a member of a replica group to request voting status. This + requires a new Raft configuration to be accepted by a majority of the + voting members of the *new* configuration. +1. **RequestMetrics** Get all metrics and gauges collected by the enclave. +1. **RefreshAttestation** Refresh attestations for peer and client connections. +1. **SetLogLevel** Sets the enclave's logging level with an `::svr2::EnclaveLogLevel` + enum. These enum values match Open Enclave's [oe_log_level_t](https://github.com/openenclave/openenclave/blob/master/include/openenclave/log.h). +1.
**RelinquishLeadership** (`bool` - _requires consensus_) If we are the Raft + leader, give it up and attempt to pass leadership to an up-to-date peer without + waiting for the election timers. +1. **RequestRemoval** (`bool` - _requires consensus_) Request that this replica be removed from the Raft + group. +1. **Hashes** (`bool`) Compute and return to the host a hash of the current DB. + +## The Peer Interface +Peer to peer calls fall into three categories: +1. Raft messages +1. Connectivity messages +1. E2E Transactions + +### Raft Messages +The Raft protocol messages defined in [enclave/raft.proto](../proto/raft.proto) +closely follow the Raft protocol defined in +[Ongaro's thesis](https://web.stanford.edu/~ouster/cgi-bin/papers/OngaroPhD.pdf). + +### Connectivity Messages +These messages are defined in [enclave/proto/e2e.proto](../proto/e2e.proto). +1. **Connect** (`e2e.ConnectRequest`) Sends attestation and handshake + information to initiate a connection with a peer. The response to this + call contains attestation and handshake information for the called + enclave. +1. **AttestationUpdate** (`Attestation`) Sends a new attestation to a peer so + that peers can ensure their long-term connection with another enclave is + still secure. + +### Enclave to Enclave (E2E) Transactions + +These messages are defined in [enclave/proto/e2e.proto](../proto/e2e.proto). + +1. **GetRaft** (`e2e.GetRaftRequest`) Gets Raft membership information so that + the enclave can initiate the joining process. +1. **ReplicateState** (`e2e.ReplicateStateRequest`) Requests a chunk of + database state from a peer. This can include log messages and database rows. +1. **ReplicateStatePush** (`e2e.ReplicateStatePush`) +1. **RaftMembershipRequest** (`bool` - _requires consensus_) Request + to become a non-voting member of a replica group by setting this `true`. + Assumes that the calling peer is loaded and up to date. +1.
**RaftVotingRequest** (`bool` - _requires consensus_) Request + to become a voting member of a replica group by setting this `true`. + Assumes that the calling peer is a non-voting member of the group. +1. **RaftWrite** (`bytes` - _requires consensus_) Forward a log entry to + Raft leader to be added to the log. +1. **Ping** (`bool`) Request from a peer for simple acknowledgement to confirm + the connection to the requesting peer's host. +1. **NewTimestampUnixSecs** (`uint64`) Contains the sending peer's timestamp. + Recipient will update peer and group times. +1. **RaftRemovalRequest** (`bool`) Creates a new replica group configuration without + the requesting peer in it, and submits this change to the new voting peers + for commitment. + + +## The Client Interface +The `client.*` messages are defined in +[client.proto](../../shared/proto/client.proto). These are sent over the Noise +encrypted channel between the client and the enclave, wrapped in an +`ExistingClientRequest` submessage of a `HostToEnclaveRequest`. + + +1. **NewClient** (`NewClientRequest`) +1. **CreateBackup** (`CreateBackupRequest` - _requires consensus_) Creates an + empty backup row in the database. +1. **Backup** (`client.BackupRequest` - _requires consensus_) Stores a new value + and resets the number of allowed tries for a given backup ID. +1. **RestoreBackup** (`client.RestoreRequest` - _requires consensus_) Presents an + authorization token/PIN for a backup ID. If the token is correct, the secure + value is retrieved from the database and sent to the client over the Noise + connection. If it is incorrect the number of allowed tries is decremented. + If no more tries remain, the database row is deleted. +1.
**DeleteBackup - client request** (`client.DeleteRequest` - _requires consensus_) diff --git a/docs/svr3spec/.gitignore b/docs/svr3spec/.gitignore new file mode 100644 index 0000000..7537d53 --- /dev/null +++ b/docs/svr3spec/.gitignore @@ -0,0 +1,5 @@ +svr3.aux +svr3.bbl +svr3.blg +svr3.log +svr3.out \ No newline at end of file diff --git a/docs/svr3spec/README.md b/docs/svr3spec/README.md new file mode 100644 index 0000000..7a5303a --- /dev/null +++ b/docs/svr3spec/README.md @@ -0,0 +1,11 @@ +## Building the PDF + +To build the pdf from source, you will need to install pdflatex and bibtex - the [TeXLive](https://www.tug.org/texlive/) distribution is probably the simplest way to do this. Alternatively, you can use an online system like [Overleaf](https://overleaf.com) that will take care of most of the LaTeX related headaches for you. + +With these installed, run the following commands: +``` +pdflatex svr3.tex # produces initial pdf and computes references needed in svr3.aux +bibtex svr3 # builds bibliography +pdflatex svr3.tex # incorporates bibliography into the pdf +``` +If prompted for input during the last run of pdflatex, press "enter" to continue. diff --git a/docs/svr3spec/svr3.bib b/docs/svr3spec/svr3.bib new file mode 100644 index 0000000..38bf07f --- /dev/null +++ b/docs/svr3spec/svr3.bib @@ -0,0 +1,40 @@ + +@misc{jkkx, + author = {Stanislaw Jarecki and Aggelos Kiayias and Hugo Krawczyk and Jiayu Xu}, + title = {Highly-Efficient and Composable Password-Protected Secret Sharing (Or: How to Protect Your Bitcoin Wallet Online)}, + howpublished = {Cryptology ePrint Archive, Paper 2016/144}, + year = {2016}, + note = {\url{https://eprint.iacr.org/2016/144}}, + url = {https://eprint.iacr.org/2016/144} +} + +@misc{poprf, + author = {Nirvan Tyagi and Sof{\'{\i}}a Celi and Thomas Ristenpart and Nick Sullivan and Stefano Tessaro and Christopher A.
Wood}, + title = {A Fast and Simple Partially Oblivious PRF, with Applications}, + howpublished = {Cryptology ePrint Archive, Paper 2021/864}, + year = {2021}, + note = {\url{https://eprint.iacr.org/2021/864}}, + url = {https://eprint.iacr.org/2021/864} +} + +@book{bonehshoup, + author = {Dan Boneh and Victor Shoup}, + title = {A Graduate Course in Applied Cryptography}, + note = {\url{https://toc.cryptobook.us/book.pdf}}, + url = {https://toc.cryptobook.us/book.pdf} +} +@misc{2hashdh, + author = {Stanislaw Jarecki and Aggelos Kiayias and Hugo Krawczyk}, + title = {Round-Optimal Password-Protected Secret Sharing and T-PAKE in the Password-Only Model}, + howpublished = {Cryptology ePrint Archive, Paper 2014/650}, + year = {2014}, + note = {\url{https://eprint.iacr.org/2014/650}}, + url = {https://eprint.iacr.org/2014/650} +} + +@misc {ietf-oprf, + author = {A. Davidson and A. Faz-Hernandez and N. Sullivan and C. A. Wood}, + title = {Oblivious Pseudorandom Functions (OPRFs) using Prime-Order Groups}, + url = {https://www.ietf.org/id/draft-irtf-cfrg-voprf-21.html#name-informative-references-7}, + note = {\url{https://www.ietf.org/id/draft-irtf-cfrg-voprf-21.html}} +} \ No newline at end of file diff --git a/docs/svr3spec/svr3.pdf b/docs/svr3spec/svr3.pdf new file mode 100644 index 0000000..b8d9263 Binary files /dev/null and b/docs/svr3spec/svr3.pdf differ diff --git a/docs/svr3spec/svr3.tex b/docs/svr3spec/svr3.tex new file mode 100644 index 0000000..2ef59ba --- /dev/null +++ b/docs/svr3spec/svr3.tex @@ -0,0 +1,349 @@ +\documentclass{article} + +% Language setting Replace `english' with e.g.
`spanish' to change the document +% language +\usepackage[english]{babel} + +% Set page size and margins Replace `letterpaper' with `a4paper' for UK/EU +% standard size +\usepackage[letterpaper,top=2cm,bottom=2cm,left=3cm,right=3cm,marginparwidth=1.75cm]{geometry} + +% Useful packages +\usepackage{amsmath} +\usepackage{graphicx} +\usepackage[colorlinks=true, allcolors=blue]{hyperref} +\usepackage[ n, % or lambda +advantage, operators, sets, adversary, landau , probability, notions, logic, ff, +mm, primitives, events, complexity, oracles, asymptotics, keys]{cryptocode} +%% Primitives +\newcommand{\OPRF}{\pcalgostyle{OPRF}} +\newcommand{\POPRF}{\pcalgostyle{POPRF}} +\newcommand{\VOPRF}{\pcalgostyle{VOPRF}} + +\newcommand{\Blind}{\pcalgostyle{Blind}} +\newcommand{\BlindEvaluate}{\pcalgostyle{BlindEvaluate}} +\newcommand{\BlindEvaluateForClient}{\pcalgostyle{BlindEvaluateForClient}} +\newcommand{\Finalize}{\pcalgostyle{Finalize}} + +\newcommand{\PPSSStore}{\pcalgostyle{PPSSStore}} +\newcommand{\PPSSRecover}{\pcalgostyle{PPSSRecover}} + +\newcommand{\ServerCreateOPRFVersion}{\pcalgostyle{ServerCreateOPRFVersion}} + + +%% Hashes +\newcommand{\HashToPoint}{\pcalgostyle{HashToPoint}} +\newcommand{\HashToScalar}{\pcalgostyle{HashToScalar}} +\newcommand{\HashToField}{\pcalgostyle{HashToField}} +\newcommand{\EncodeToField}{\pcalgostyle{EncodeToField}} + +%% Variables +\newcommand{\oprfinput}{\pcalgostyle{oprf\_input}} +\newcommand{\oprfkeys}{\pcalgostyle{oprf\_keys}} +\newcommand{\usage}{\pcalgostyle{usage}} +\newcommand{\usagecount}{\pcalgostyle{usage\_count}} +\newcommand{\maxuses}{\pcalgostyle{max\_uses}} +\newcommand{\blind}{\pcalgostyle{blind}} +\newcommand{\blindedElement}{\pcalgostyle{blindedElement}} +\newcommand{\evaluatedElement}{\pcalgostyle{evaluatedElement}} + +\newcommand{\clientstate}{\pcalgostyle{client\_state}} +\newcommand{\serverstate}{\pcalgostyle{server\_state}} +\newcommand{\clientid}{\pcalgostyle{client\_id}} +\newcommand{\client}{\pcalgostyle{client}} 
+\newcommand{\server}{\pcalgostyle{server}} +\newcommand{\servers}{\pcalgostyle{servers}} +\newcommand{\name}{\pcalgostyle{name}} +\newcommand{\context}{\pcalgostyle{context}} +\newcommand{\return}{\ensuremath{\mathbf{return}\ }} + + +\title{DRAFT Guess Limited Password Protected Secret Sharing Proposal} +\author{Rolfe Schmidt} + + +\begin{document} +\maketitle + +\section{Overview} + +This is a protocol for {\em guess-limited password-based secure value recovery}. +It allows clients to interact with servers to securely reconstruct a secret +using a password while providing protection against both offline and online +dictionary attacks - even in the event of server compromise. It protects against +online dictionary attacks through guess limiting: after a configured number of +failed reconstruction attempts, the secure value becomes unrecoverable. + +\subsection{Outline} + +This protocol is a variation of the PPSS protocol of \cite{jkkx} implemented +with a {\em usage limited} version of the standards track 2HashDH $\OPRF$ of +\cite{2hashdh} as specified in \cite{ietf-oprf} that can be used safely with +smaller curves like Ristretto255. + +After covering notation in section \ref{sec:notation} we present an augmentation +of the standards track $\OPRF$ of \cite{ietf-oprf} in section \ref{sec:oprf} +that has servers generate per-client $\OPRF$ keys, enforces strict usage limits +on these keys, and allows clients to rotate their keys to avoid running into +usage limits. + +In section \ref{sec:ppss} we use this usage limited $\OPRF$ to construct a +secure PPSS. This protocol is close to that of \cite{jkkx}, but does not mandate +storage of masked shares on servers and eliminates the share commitment storage +on servers. We discuss ways to obtain robustness in section \ref{sec:robustness}. 
+ +Importantly, we observe that if the underlying $\OPRF$ limits clients to +$\maxuses$ per key, then against a $(t,N)$ threshold scheme an attacker will be +limited to $\lfloor \frac{N}{t+1}\maxuses\rfloor$ password guess attempts before +the secret becomes unrecoverable. Thus our PPSS is {\em guess limited}. We also +show how keys can be deleted from the server to offer a form of forward security +in case of server compromise. + + +\section{Notation} +\label{sec:notation} +{\bf Algebraic objects.} This protocol will use a prime order cyclic group, +$\GG$ among those specified in \cite{ietf-oprf}. Since key use will be limited, +we can take $\GG$ to be Ristretto255. We +denote the order of $\GG$ by $q$ and thus denote the set of scalars for $\GG$ by +$\ZZ_{q}$. Group elements will be denoted by capital Latin letters, e.g. $A, B, +C, \ldots$. Scalars will be denoted by lower case Latin letters, e.g. $a, b, +c,\ldots$. $G$ denotes a public generator of $\GG$. Scalar multiplication will +usually be denoted without a symbol - $aG$ - but in places the infix operator +$*$ will be used for clarity, as in $\sk_{oprf}*G$. + +Secret sharing will be performed using polynomials over a finite field, $\FF$, +that is not related to the group $\GG$. + +{\bf Domain separation.} Throughout the protocol we will use $\context$ to +denote a domain separation prefix unique to the application performing the +protocol. + +{\bf Server and client state.} Each server will have state information captured +in the variable \serverstate. The public part of this state is available in the +variable \server. + +Similarly, each client will have persistent information captured in the variable +\clientstate, and the public part of this state will be accessible through the +variable \client. + +{\bf Function parameters.} We will use a number of functions associated with +$\GG$ and $\FF$ which we consider as protocol parameters. 
In an instantiation of +the protocol the parameters will be identified in the \context string. The +function parameters are: +\begin{itemize} + \item All parameters for the $\OPRF$ specified in \cite{ietf-oprf} + \item $\HashToField: \bin^{*} \rightarrow \FF$ +\end{itemize} + + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% +%% OPRF +%% + +\section{ The $\OPRF$} +\label{sec:oprf} +The PPSS protocol relies on the verifiable 2HashDH $\OPRF$ of \cite{2hashdh} as +specified in the IRTF draft standard \cite{ietf-oprf}. We describe the protocol +using the $\OPRF$ mode of the standard but in section \ref{sec:robustness} note +situations where follow up calls to the $\VOPRF$ mode can be used to ensure +robustness. + +The IRTF standard specifies the following functions: +\begin{enumerate} + \item $(\blind, \blindedElement) \leftarrow \Blind(\oprfinput)$: Used by a + client to prepare the $\OPRF$ input to be sent to the server. + \item $\evaluatedElement \leftarrow \BlindEvaluate(\sk, + \blindedElement)$: Executed by the server to evaluate the $\OPRF$ + parameterized by the server-secret key $\sk$ on $\blindedElement$ to compute + a blinded output. + \item $v \leftarrow \Finalize(\oprfinput, \blind, \evaluatedElement, + \blindedElement)$: Takes the values returned by $\BlindEvaluate$ + along with original input, and values returned by $\Blind$ to compute the + final $\prf$ value, $v$. +\end{enumerate} + +\subsection{Usage Limited Evaluation} +The server for our protocol adds to these functions in two ways: it uses a +random per-client $\OPRF$ key stored in the dictionary $\serverstate.\oprfkeys$ +and it enforces a usage limit. Each $\OPRF$ key can only be used a fixed number +of times. Setting and rotation of these keys is discussed in +\ref{sec:versioning}. 
This is done with the function $\BlindEvaluateForClient$: + + + +\procedureblock[linenumbering]{$\BlindEvaluateForClient(\serverstate, \clientid, +\blindedElement)$}{ \usagecount \leftarrow \serverstate.\usagecount[\clientid] +\\ +\serverstate.\usagecount[\clientid] \leftarrow \usagecount + 1 \\ +\pcif \usagecount \geq \serverstate.\maxuses: \\ +\t \pcreturn \perp \\ +(\sk, \pk) \leftarrow \serverstate.\oprfkeys[\clientid] \pccomment{The client +MAY obtain the \ensuremath{\pk} corresponding to \ensuremath{\sk} at +registration} \\ +\pcreturn \BlindEvaluate(\sk, \blindedElement) } + + + +\subsection{$\OPRF$ Key Creation and Versioning} +\label{sec:versioning} +As noted in the previous section, $\OPRF$ keys are created per-client and each +key is strictly limited to a fixed number of uses. The usage limitation has two +useful purposes. First, since the security of the $\OPRF$ is based on the +one-more Diffie Hellman assumption, the security of a key used for $Q$ queries +is reduced by $\log(Q)/2$ bits (see, e.g., +\href{https://www.ietf.org/id/draft-irtf-cfrg-voprf-21.html#section-7.2.3}{7.2.3 +of the IRTF draft}). So, for example, by limiting key usage to no more than 16 +queries we only lose 2 bits of security and can safely use a group like +Ristretto255. Second, this limit enforcement will be the basis of the guess +limiting in the PPSS described in section \ref{sec:ppss}. + +Clients in our PPSS will need to reconstruct their secret an unlimited number of +times, though. To do this, upon successful reconstruction the client will create +a new version of their $\OPRF$ key. 
This new version will be constructed +with the function $\ServerCreateOPRFVersion$, which creates a new key pair, +stores it indexed by the client's identifier, clears the usage count, and +evaluates the $\OPRF$ with the new key on a blinded element: + + +\procedureblock[linenumbering]{$\ServerCreateOPRFVersion(\serverstate, +\clientid, \blindedElement)$}{ \sk \sample \ZZ_q \\ +\pk \leftarrow \sk * G \\ +\serverstate.\oprfkeys[\clientid] \leftarrow (\sk, \pk)\\ +\serverstate.\usagecount[\clientid] \leftarrow 0 \\ +\evaluatedElement \leftarrow \BlindEvaluate(\sk, \blindedElement) \\ +\pcreturn (\evaluatedElement, \pk) \pccomment{Return the public key in case \nizk\ proof is needed later} +} + +\subsection{A Note About $\POPRF$ Mode} +It is tempting to use the $\POPRF$ mode introduced in \cite{poprf} rather than +generating client specific keys. If usage limitation were not a requirement this +would have a clear advantage - the server state would be no more than one secret +$\OPRF$ scalar. However, once we introduce the need for key usage limits and +key rotation this advantage disappears. Usage limitation requires storage of +per-client state. Key rotation requires the use of a nonce or a counter for each +client, effectively requiring the same storage as the proposed per-client key +solution. + +\section{A Guess Limited PPSS from the \OPRF} +\label{sec:ppss} +With these primitives in place we define the PPSS scheme with the functions +$\PPSSStore$ and $\PPSSRecover$. The idea is simple. To create a $(t,N)$ +threshold PPSS to store a secret $s$ with $N$ servers we +\begin{enumerate} + \item Create a degree $t$ polynomial $f\in\FF[x]$ with $s$ as the leading + coefficient, all other coefficients random. + \item Create a share for each server: $s_i = f(x_i)$ where $x_i = + \HashToField(\server_{i}.id)$ + \item Use the $\OPRF$ values to mask the shares: $m_i = s_i + + \server_i.\OPRF(\clientid, pwd)$.
+ \item Store the values $m_i$ somewhere reliable, but confidentiality is not + important. + \item To reconstruct, simply call $(t+1)$ or more servers to get their + $\OPRF$ values and use these to unmask the shares: $s_i = m_i - + \server_i.\OPRF(\clientid, pwd)$. + \item These shares can now be used to reconstruct the secret $s$. + \item Upon successful reconstruction the client can create new key versions + on all servers, refresh their guess counts, and create new masked shares. + All of this can be done without changing the password or master secret. +\end{enumerate} + +In the following $\servers$ is a set of $N$ $\server$ objects, $\mathbf{e}$ +is a dictionary that will store masked shares of a secret $s$, and $\mathbf{pks}$ +is a dictionary that stores server $\OPRF$ public keys. +\procedureblock[linenumbering]{$\PPSSStore(\clientstate, \servers, t, pwd, s)$}{ +r \concat K \leftarrow \hash(\context \concat ``keygen", s) \\ +\forall i \in [0,t-1] : f_i \sample \FF \\ +f_{t} \leftarrow \EncodeToField(s) \\ +\pcfor \server \in \servers: \\ +\t \oprfinput \leftarrow \context \concat \server.id \concat pwd \\ +\t x \leftarrow \HashToField(\server.id) \\ +\t y \leftarrow \sum_{i=0}^{t} f_i x^i \\ +\t (\blind, \blindedElement) \leftarrow \Blind(\oprfinput) \\ +\t (\evaluatedElement, \pk) \leftarrow +\server.\ServerCreateOPRFVersion(\clientstate.id, \blindedElement) \\ +\t \rho \leftarrow \Finalize(\oprfinput, \blind, \evaluatedElement,\blindedElement) \\ +\pccomment{ \ensuremath{\mathbf{e}} and \ensuremath{\mathbf{pks}} should be +stored somewhere reliable, but confidentiality is not needed} \\ +\t \mathbf{s}[x] \leftarrow y \\ +\t \clientstate.\mathbf{e}[x] \leftarrow y + \rho \\ +\t \clientstate.\mathbf{pks}[\server.id] \leftarrow \pk \\ +\clientstate.C \leftarrow \hash(\context\concat ``commitment", pwd, \clientstate.\mathbf{e}, \mathbf{s}, r) \\ +\pcreturn K +} + +\procedureblock[linenumbering]{$\PPSSRecover(\clientstate, \servers, t, pwd)$}{ +\text{Choose } 
\mathcal{R} \subset \servers, |\mathcal{R}| > t \\ +pairs \leftarrow \{\} \\ +\pcfor \server \in \mathcal{R}: \\ +\t \oprfinput \leftarrow \context \concat \server.id \concat pwd \\ +\t x \leftarrow \HashToField(\server.id) \\ +\t (\blind, \blindedElement) \leftarrow \Blind(\oprfinput) \\ +\t \evaluatedElement \leftarrow +\server.\BlindEvaluateForClient(\clientstate.id, \blindedElement) \\ +\t \rho \leftarrow \Finalize(\oprfinput, \blind, \evaluatedElement,\blindedElement) \\ +\t y \leftarrow \clientstate.\mathbf{e}[x] - \rho \\ +\t \mathbf{s}[x] \leftarrow y \\ +\t pairs \leftarrow pairs \cup \{(x,y)\} \\ +(f_{t}, \ldots, f_0) \leftarrow \pcalgostyle{Interpolate}_{\FF}(pairs) \\ +s \leftarrow f_{t} \\ +r \concat K \leftarrow \hash(\context \concat ``keygen", s) \\ +C \leftarrow \hash(\context\concat ``commitment", pwd, \clientstate.\mathbf{e}, \mathbf{s}, r) \\ +\pcif C \neq \clientstate.C: \\ +\t \pcreturn \perp \\ +\pcelse \\ +\t \PPSSStore(\clientstate, \servers,t,pwd, f_{t}) \pccomment{store the secret again +to reset keys, counters, and shares} \\ +\t \pcreturn K } + +\subsection{Usage Limits on the $\OPRF$ lead to Guess Limits on the PPSS} +Now we can see how the usage limit we enforce on the $\OPRF$ naturally creates a +guess limit on the PPSS that provides protection against online dictionary +attacks. Consider the scenario where a client has constructed a $(t,N)$-sharing +scheme to protect a secret $s$ with password $pwd$ using $\PPSSStore$. Now an +attacker trying to guess the password and recover the secret faces the following +fact: each password guess requires using $t+1$ $\OPRF$ calls, and only +$\maxuses$ are possible on each of the $N$ servers. Thus the attacker has no +more than $\lfloor\frac{N}{t+1}\maxuses\rfloor$ guesses before the secret +becomes unrecoverable. + +\subsection{Deleting Keys} +\label{sec:deleting} +A client can protect themselves from future server compromise by deleting keys +from the server.
This can be done by simply calling $\ServerCreateOPRFVersion$ +with arbitrary $\oprfinput$ and discarding the result. For the user to have +confidence that the keys were in fact deleted - now and during each $\PPSSStore$ +call - server functions can be executed in an attested, confidential TEE. + +Additionally, in the case that a TEE based server is being retired it can +produce an attested certificate of secret deletion. If a client has confidence +that their secrets have, in fact, been deleted from a server then they know that +their $(t,N)$ threshold scheme has become a $(t,N-1)$ scheme and they can safely +add a new server. + +\section{Robustness} +\label{sec:robustness} +Unlike \cite{jkkx} we do not store the commitment, $C$, and the masked shares, +$\mathbf{e}$, on each server. Instead we will have clients store $\clientstate$, +which includes both these values, in a reliable place as suggested in \cite{2hashdh}. +We then rely on subset testing or follow-up $\VOPRF$ calls to +detect incorrect servers. + +The sole reason the server $\pk$ values are stored by the client in the protocol +above is to allow follow-up $\nizk$ proof verification if the $\VOPRF$ mode +is used for robustness. If only subset testing will be used (e.g. for small +values of $N$) then these public keys do not need to be stored. + +\section{Acknowledgements} +We would like to thank Mark Johnson for helping to develop an earlier version of +this protocol and Trevor Perrin for important feedback and pointers to the +literature. We also thank Emma Dautermann, Vivian Fang, and Raluca Ada Popa for +discussion that led to significant design decisions and simplifications of this +protocol.
+ +\bibliographystyle{alpha} +\bibliography{svr3} + +\end{document} \ No newline at end of file diff --git a/enclave/.gitignore b/enclave/.gitignore new file mode 100644 index 0000000..bdd1406 --- /dev/null +++ b/enclave/.gitignore @@ -0,0 +1,4 @@ +build/ +.testdepends +*.pem +noise-c/ diff --git a/enclave/Makefile b/enclave/Makefile new file mode 100644 index 0000000..7f24f05 --- /dev/null +++ b/enclave/Makefile @@ -0,0 +1,232 @@ +all: test build sign + +MAKEFILTER=| (grep --line-buffered -v '^make\[' || true) + +include Makefile.base + +.PHONY: all build sign clean sign protos validate generatem + +build: build/enclave.bin build/enclave.nsm + +sign: build/enclave.signed build/enclave.test build/enclave.small + +PROTO_FILES= \ + $(patsubst ../shared/proto/%.proto,build/proto/%.pb.cc,$(wildcard ../shared/proto/*.proto)) \ + $(patsubst ../shared/proto/%.proto,build/proto/%.pb.h,$(wildcard ../shared/proto/*.proto)) \ + $(patsubst proto/%.proto,build/proto/%.pb.cc,$(wildcard proto/*.proto)) \ + $(patsubst proto/%.proto,build/proto/%.pb.h,$(wildcard proto/*.proto)) \ +## PROTO_FILES +protos: $(PROTO_FILES) + +build/proto: + $(QUIET) echo -e "MKDIR $@" + $(QUIET) mkdir -p $@ +build/proto/%.pb.h build/proto/%.pb.cc: proto/%.proto | build/proto + $(QUIET) echo -e "PROTO\t$^" + $(QUIET) protoc --proto_path=../shared/proto --proto_path=proto --cpp_out=build/proto $^ +build/proto/%.pb.h build/proto/%.pb.cc: ../shared/proto/%.proto | build/proto + $(QUIET) echo -e "PROTO\t$^" + $(QUIET) protoc --proto_path=../shared/proto --cpp_out=build/proto $^ + +build/gtest/TEST.a: + $(QUIET) $(MAKE) -f Makefile.subdir DIR=gtest ENV=TEST ADDITIONAL_CFLAGS="-I$(CURDIR)/googletest/googletest" $(MAKEFILTER) + +build/noise-c/TEST.a: build/libsodium/TEST.a + $(QUIET) echo -e "BUILD\t$@" + $(QUIET) mkdir -p $(@D) + $(QUIET) (cd noise-c && \ + ./autogen.sh && \ + libsodium_CFLAGS=-I$$PWD/../build/libsodium/TEST.a.dir/include/ libsodium_LIBS=$$PWD/../build/libsodium/TEST.a \ + CC=$(CC) 
CFLAGS="$(TEST_CFLAGS) -I$(shell ./find_header.sh $(CC) immintrin.h)" ./configure --with-libsodium && \ + $(MAKE) clean && \ + $(MAKE)) $(QUIET_OUT) + $(QUIET) cp noise-c/src/protocol/libnoiseprotocol.a $@ + $(QUIET) echo -e "BUILT\t$@" +build/noise-c/SGX.a: build/libsodium/SGX.a | build/noise-c/TEST.a + $(QUIET) echo -e "BUILD\t$@" + $(QUIET) mkdir -p $(@D) + $(QUIET) (cd noise-c && \ + ./autogen.sh && \ + libsodium_CFLAGS=-I$$PWD/../build/libsodium/SGX.a.dir/include/ libsodium_LIBS=$$PWD/../build/libsodium/SGX.a \ + CC=$(CC) CFLAGS="$(SGX_CFLAGS) -I$(shell ./find_header.sh $(CC) immintrin.h)" ./configure --with-libsodium && \ + $(MAKE) clean && \ + $(MAKE)) $(QUIET_OUT) + $(QUIET) cp noise-c/src/protocol/libnoiseprotocol.a $@ + $(QUIET) echo -e "BUILT\t$@" +build/noise-c/NSM.a: build/libsodium/NSM.a | build/noise-c/SGX.a + $(QUIET) echo -e "BUILD\t$@" + $(QUIET) mkdir -p $(@D) + $(QUIET) (cd noise-c && \ + ./autogen.sh && \ + libsodium_CFLAGS=-I$$PWD/../build/libsodium/NSM.a.dir/include/ libsodium_LIBS=$$PWD/../build/libsodium/NSM.a \ + CC=$(CC) CFLAGS="$(NSM_CFLAGS) -I$(shell ./find_header.sh $(CC) immintrin.h)" ./configure --with-libsodium && \ + $(MAKE) clean && \ + $(MAKE)) $(QUIET_OUT) + $(QUIET) cp noise-c/src/protocol/libnoiseprotocol.a $@ + $(QUIET) echo -e "BUILT\t$@" + +# libsodium's ./configure script incorrectly detects that mmap, mlock, madvise, mprotect, +# and raise are all available, when in fact they are not in the enclave. This set of flags +# allows us to undo that. 
+LIBSODIUM_UNDEFS=-UHAVE_MMAP -UHAVE_MLOCK -UHAVE_MADVISE -UHAVE_MPROTECT -UHAVE_RAISE +##LIBSODIUM_UNDEFS +build/libsodium/TEST.a: + $(QUIET) echo -e "BUILD\t$@" + $(QUIET) mkdir -p $@.dir $(@D) + $(QUIET) (cd libsodium && (git clean -fx || true) && ./configure \ + CFLAGS="$(TEST_CFLAGS)" \ + CXXFLAGS="$(TEST_CXXFLAGS)" \ + CC=$(CC) CXX=$(CXX) --prefix=$$PWD/../$@.dir && $(MAKE) clean && $(MAKE) install) $(QUIET_OUT) + $(QUIET) ln -s $$PWD/$@.dir/lib/libsodium.a $@ + $(QUIET) echo -e "BUILT\t$@" +build/libsodium/SGX.a: | build/libsodium/TEST.a + $(QUIET) echo -e "BUILD\t$@" + $(QUIET) mkdir -p $@.dir $(@D) + $(QUIET) (cd libsodium && (git clean -fx || true) && ./configure \ + CFLAGS="$(SGX_CFLAGS) $(LIBSODIUM_UNDEFS)" \ + CXXFLAGS="$(SGX_CXXFLAGS) $(LIBSODIUM_UNDEFS)" \ + CC=$(CC) CXX=$(CXX) --prefix=$$PWD/../$@.dir && $(MAKE) clean && $(MAKE) install) $(QUIET_OUT) + $(QUIET) ln -s $$PWD/$@.dir/lib/libsodium.a $@ + $(QUIET) echo -e "BUILT\t$@" +build/libsodium/NSM.a: | build/libsodium/SGX.a + $(QUIET) echo -e "BUILD\t$@" + $(QUIET) mkdir -p $@.dir $(@D) + $(QUIET) (cd libsodium && (git clean -fx || true) && ./configure \ + CFLAGS="$(NSM_CFLAGS) $(LIBSODIUM_UNDEFS)" \ + CXXFLAGS="$(NSM_CXXFLAGS) $(LIBSODIUM_UNDEFS)" \ + CC=$(CC) CXX=$(CXX) --prefix=$$PWD/../$@.dir && $(MAKE) clean && $(MAKE) install) $(QUIET_OUT) + $(QUIET) ln -s $$PWD/$@.dir/lib/libsodium.a $@ + $(QUIET) echo -e "BUILT\t$@" + +EDGER8R_FILES=build/svr2/svr2_t.h build/svr2/svr2_t.c build/svr2/svr2_args.h +# This $(firstword) trick allows for grouped targets. 
+$(filter-out $(firstword $(EDGER8R_FILES)),$(EDGER8R_FILES)): $(firstword $(EDGER8R_FILES)) +$(firstword $(EDGER8R_FILES)): ../shared/svr2.edl + $(QUIET) echo -e "EDGER8\t$(EDGER8R_FILES)" + $(QUIET) mkdir -p $(@D) + $(QUIET) $(OE_EDGER8R) $< --trusted \ + --trusted-dir build/svr2 \ + --search-path $(OE_INCDIR) \ + --search-path $(OE_INCDIR)/openenclave/edl/sgx $(QUIET_OUT) + +generated: $(EDGER8R_FILES) $(PROTO_FILES) +build/%/SGX.a: generated + $(QUIET) $(MAKE) -f Makefile.subdir DIR=$* ENV=SGX $(MAKEFILTER) +build/%/NSM.a: generated + $(QUIET) $(MAKE) -f Makefile.subdir DIR=$* ENV=NSM $(MAKEFILTER) +build/%/TEST.a: generated + $(QUIET) $(MAKE) -f Makefile.subdir DIR=$* ENV=TEST $(MAKEFILTER) +build/%/HOST.a: generated + $(QUIET) $(MAKE) -f Makefile.subdir DIR=$* ENV=HOST $(MAKEFILTER) +.PHONY: build/%/SGX.a build/%/TEST.a build/%/HOST.a build/%/NSM.a + +# All libraries which will become part of enclave.bin. If A depends on B, then A should be added before B. +SGX_LIBRARIES = \ + svr2 \ + ecalls \ + core \ + timeout \ + client \ + db \ + raft \ + groupclock \ + peers \ + peerid \ + sender \ + util \ + context \ + hmac \ + noise \ + noise-c \ + noisewrap \ + env \ + env/sgx \ + sip \ + attestation \ + metrics \ + proto \ + protobuf-lite \ + libsodium \ +## SGX_LIBRARIES + +build/enclave.bin: $(patsubst %,build/%/SGX.a,$(SGX_LIBRARIES)) + $(QUIET) echo -e "BUILD\t$@" + $(QUIET) $(CXX) -o $@ $(SGX_LDFLAGS) $^ $(SGX_LDFLAGS) + +build/enclave.signed: build/enclave.bin build/public.pem build/private.pem svr2.conf + $(QUIET) echo -e "SIGN\t$@" + $(QUIET) $(OE_DIR)/bin/oesign sign -e $< -c svr2.conf -k build/private.pem -o $@ $(QUIET_OUT) + +build/enclave.small: build/enclave.bin build/public.pem build/private.pem svr2_small.conf + $(QUIET) echo -e "SIGN\t$@" + $(QUIET) $(OE_DIR)/bin/oesign sign -e $< -c svr2_small.conf -k build/private.pem -o $@ $(QUIET_OUT) + +build/enclave.test: build/enclave.bin build/public.pem build/private.pem svr2_test.conf + $(QUIET) echo -e 
"SIGN\t$@" + $(QUIET) $(OE_DIR)/bin/oesign sign -e $< -c svr2_test.conf -k build/private.pem -o $@ $(QUIET_OUT) + +NSM_LIBRARIES = \ + nitromain \ + core \ + timeout \ + client \ + db \ + raft \ + groupclock \ + peers \ + peerid \ + sender \ + util \ + hmac \ + noise \ + noise-c \ + noisewrap \ + env \ + env/nsm \ + sip \ + socketwrap \ + context \ + metrics \ + proto \ + protobuf-lite \ + libsodium \ +## NSM_LIBRARIES + +build/enclave.nsm: $(patsubst %,build/%/NSM.a,$(NSM_LIBRARIES)) + $(QUIET) echo -e "BUILD\t$@" + $(QUIET) $(CXX) -o $@ $(NSM_LDFLAGS) $^ $(NSM_LDFLAGS) + +clean: + $(QUIET) (cd protobuf ; make clean ; git clean -fx ; true) $(QUIET_OUT) + $(QUIET) (cd noise-c ; make clean ; git clean -fx ; true) $(QUIET_OUT) + $(QUIET) (cd SipHash ; make clean ; git clean -fx ; true) $(QUIET_OUT) + $(QUIET) rm -vfr build $(QUIET_OUT) + $(QUIET) rm -vf .testdepends $(QUIET_OUT) + +build/private.pem: + $(QUIET) echo -e "KEY\t$@" + $(QUIET) mkdir -p $(@D) + $(QUIET) openssl genrsa -out $@ -3 3072 $(QUIET_OUT) +build/public.pem: build/private.pem + $(QUIET) echo -e "KEY\t$@" + $(QUIET) openssl rsa -in $< -pubout -out $@ $(QUIET_OUT) + +%.test.out: %.test + $(QUIET) echo -e "TEST\t$<" + $(QUIET) ./$^ --gtest_color=yes &>$@ || (cat $@; false) + $(QUIET) echo -e "TEST\xE2\x9c\x85\t$<" + +build/testhost/libsvr2.a: + $(QUIET) mkdir -p $(@D) + $(CC) -c -o build/testhost/svr2.o $(HOST_CFLAGS) ../host/enclave/c/svr2_u.c + ar rcs $@ build/testhost/svr2.o + +build/testhost.bin: testhost/testhost.cc build/testhost/libsvr2.a build/attestation/HOST.a build/metrics/HOST.a build/proto/HOST.a build/protobuf-lite/HOST.a + $(CXX) -o $@ $(HOST_CXXFLAGS) $(HOST_LDFLAGS) $^ $(HOST_LDFLAGS) + +.testdepends: $(shell find ./ -type f | grep /tests/ | grep cc$) + $(QUIET) ./test_deps.sh $(QUIET_OUT) +include .testdepends + +test: +validate: diff --git a/enclave/Makefile.HOST b/enclave/Makefile.HOST new file mode 100644 index 0000000..a7973f5 --- /dev/null +++ b/enclave/Makefile.HOST @@ -0,0 
+1,4 @@ +include Makefile.base + +CFLAGS ?= $(HOST_CFLAGS) +CXXFLAGS ?= $(HOST_CXXFLAGS) diff --git a/enclave/Makefile.NSM b/enclave/Makefile.NSM new file mode 100644 index 0000000..33c2e3c --- /dev/null +++ b/enclave/Makefile.NSM @@ -0,0 +1,4 @@ +include Makefile.base + +CFLAGS ?= $(NSM_CFLAGS) +CXXFLAGS ?= $(NSM_CXXFLAGS) diff --git a/enclave/Makefile.SGX b/enclave/Makefile.SGX new file mode 100644 index 0000000..3ec361e --- /dev/null +++ b/enclave/Makefile.SGX @@ -0,0 +1,4 @@ +include Makefile.base + +CFLAGS ?= $(SGX_CFLAGS) +CXXFLAGS ?= $(SGX_CXXFLAGS) diff --git a/enclave/Makefile.TEST b/enclave/Makefile.TEST new file mode 100644 index 0000000..230f21e --- /dev/null +++ b/enclave/Makefile.TEST @@ -0,0 +1,4 @@ +include Makefile.base + +CFLAGS ?= $(TEST_CFLAGS) +CXXFLAGS ?= $(TEST_CXXFLAGS) diff --git a/enclave/Makefile.base b/enclave/Makefile.base new file mode 100644 index 0000000..3f05361 --- /dev/null +++ b/enclave/Makefile.base @@ -0,0 +1,164 @@ +SHELL=/bin/bash -o pipefail # needed for pipefail +CXX=clang++-11 +CC=clang-11 +OE_DIR ?= /opt/openenclave +OE_EDGER8R = $(OE_DIR)/bin/oeedger8r +ADDITIONAL_CFLAGS ?= +ifeq ($(V),) + QUIET=@ + QUIET_OUT=&>/dev/null +else + QUIET= + QUIET_OUT= +endif + +SECURITY_CFLAGS = \ + -fstack-protector-strong \ + -fstack-clash-protection \ + -mshstk \ + -D_FORTIFY_SOURCE=3 \ + -fsanitize=bounds \ + -fsanitize-undefined-trap-on-error \ +## SECURITY_CFLAGS +BASE_CFLAGS = \ + -fPIC \ + -iquote $(CURDIR) \ + -iquote $(CURDIR)/build \ + -g \ + -DOE_API_VERSION=2 \ + -Wthread-safety \ + -O2 \ + $(SECURITY_CFLAGS) \ + $(ADDITIONAL_CFLAGS) \ +## BASE_CFLAGS +BASE_CXXFLAGS = \ + $(BASE_CFLAGS) \ + -std=c++17 \ +## BASE_CXXFLAGS + +BASE_LDFLAGS = \ + -Wl,-wrap=noise_rand_bytes \ + -Wl,-z,relro \ + -Wl,-z,now \ + -Wl,-z,noexecstack \ + -Wl,-z,separate-code \ +## BASE_LDFLAGS + +LIBRARY_CFLAGS = \ + -I$(CURDIR)/protobuf/src \ + -I$(CURDIR)/noise-c/include \ + -I$(CURDIR)/googletest/googletest/include \ + 
-I$(CURDIR)/libsodium/src/libsodium/include \ +## LIBRARY_CFLAGS + +TEST_CFLAGS ?= \ + $(BASE_CFLAGS) \ + $(LIBRARY_CFLAGS) \ + -DIS_TEST \ +## TEST_CFLAGS +TEST_CXXFLAGS ?= \ + $(BASE_CXXFLAGS) \ + $(LIBRARY_CFLAGS) \ + -DIS_TEST \ +## TEST_CXXFLAGS +TEST_LDFLAGS ?= \ + $(BASE_LDFLAGS) \ + -lpthread \ +## TEST_LDFLAGS + +OE_CFLAGS ?= $(shell pkg-config oeenclave-clang --cflags) +SGX_CFLAGS ?= \ + $(BASE_CFLAGS) \ + $(OE_CFLAGS) \ + $(LIBRARY_CFLAGS) \ +## SGX_CFLAGS +OE_CXXFLAGS ?= $(shell pkg-config oeenclave-clang++ --cflags) +SGX_CXXFLAGS ?= \ + $(BASE_CXXFLAGS) \ + $(OE_CXXFLAGS) \ + $(LIBRARY_CFLAGS) \ +## SGX_CXXFLAGS +OE_LDFLAGS ?= $(shell pkg-config oeenclave-clang++ --libs) +OE_MBEDTLS_LDFLAGS ?= $(shell pkg-config oeenclave-clang++ --variable=mbedtlslibs) +SGX_LDFLAGS ?= \ + $(BASE_LDFLAGS) \ + $(OE_LDFLAGS) \ + $(OE_MBEDTLS_LDFLAGS) \ +## SGX_LDFLAGS + +NSM_CFLAGS ?= \ + $(BASE_CFLAGS) \ + $(LIBRARY_CFLAGS) \ + -I/opt/nsm \ + -mllvm -x86-speculative-load-hardening \ +## NSM_CFLAGS +NSM_CXXFLAGS ?= \ + $(BASE_CXXFLAGS) \ + $(LIBRARY_CFLAGS) \ + -I/opt/nsm \ +## NSM_CXXFLAGS +NSM_LDFLAGS ?= \ + $(BASE_LDFLAGS) \ + /opt/nsm/libnsm.a \ + -lpthread \ + -lrt \ + -ldl \ +## NSM_LDFLAGS + +OE_INCDIR = $(shell pkg-config oeenclave-clang++ --variable=includedir) + +OE_HOST_CFLAGS ?= $(shell pkg-config oehost-clang --cflags) +OE_HOST_CXXFLAGS ?= $(shell pkg-config oehost-clang++ --cflags) +OE_HOST_LDFLAGS ?= $(shell pkg-config oehost-clang++ --libs) +OE_HOST_MBEDTLS_LDFLAGS ?= $(shell pkg-config oehost-clang++ --variable=mbedtlslibs) +HOST_CFLAGS ?= \ + $(BASE_CFLAGS) \ + $(OE_HOST_CFLAGS) \ + $(LIBRARY_CFLAGS) \ +## HOST_CFLAGS +HOST_CXXFLAGS ?= \ + $(BASE_CXXFLAGS) \ + $(OE_HOST_CXXFLAGS) \ + $(LIBRARY_CFLAGS) \ +## HOST_CXXFLAGS +HOST_LDFLAGS ?= \ + $(BASE_LDFLAGS) \ + $(OE_HOST_LDFLAGS) \ + $(OE_HOST_MBEDTLS_LDFLAGS) \ +## HOST_LDFLAGS + +WARNING_CFLAGS ?= \ + -Werror \ + -Wall \ + -Wextra \ + -Wpedantic \ + -Walloca \ + -Wcast-qual \ + -Wformat=2 \ + 
-Wformat-security \ + -Wnull-dereference \ + -Wstack-protector \ + -Wvla \ + -Warray-bounds \ + -Warray-bounds-pointer-arithmetic \ + -Wassign-enum \ + -Wbad-function-cast \ + -Wfloat-equal \ + -Wformat-type-confusion \ + -Widiomatic-parentheses \ + -Wimplicit-fallthrough \ + -Wloop-analysis \ + -Wpointer-arith \ + -Wshift-sign-overflow \ + -Wtautological-constant-in-range-compare \ + -Wunreachable-code-aggressive \ + -Wthread-safety \ + -Wthread-safety-beta \ + -Wcomma \ + -Wno-unused-parameter \ + -Wno-bitwise-op-parentheses \ + -Wno-shift-op-parentheses \ + -Wno-c++20-designator \ + -Wno-zero-length-array \ + -Wno-c99-extensions \ +##WARNING_CFLAGS diff --git a/enclave/Makefile.subdir b/enclave/Makefile.subdir new file mode 100644 index 0000000..028a1d4 --- /dev/null +++ b/enclave/Makefile.subdir @@ -0,0 +1,41 @@ +include Makefile.$(ENV) + +all: +.PHONY: all + +BUILD = build/$(DIR) +$(BUILD): + $(QUIET) echo -e "MKDIR\t$@" + $(QUIET) mkdir -p $(BUILD) + +# We use WARNING_CFLAGS only when the file exists, is not a symlink, +# and isn't generated code (IE: it's not in build/...) +NO_WARNINGS=-w + +OBJECTS=$(patsubst %,build/%.$(ENV).o,$(wildcard $(DIR)/*.c) $(wildcard $(DIR)/*.cc)) $(patsubst %,%.$(ENV).o,$(wildcard $(BUILD)/*.c) $(wildcard $(BUILD)/*.cc)) + +$(BUILD)/%.cc.$(ENV).d $(BUILD)/%.cc.$(ENV).o: $(BUILD)/%.cc | $(BUILD) + $(QUIET) echo -e "CXX\t$<.$(ENV)" + $(QUIET) $(CXX) -c -o $(BUILD)/$*.cc.$(ENV).o $(CXXFLAGS) -MF $(BUILD)/$*.cc.$(ENV).d -MMD $< $(NO_WARNINGS) + +$(BUILD)/%.c.$(ENV).d $(BUILD)/%.c.$(ENV).o: $(BUILD)/%.c | $(BUILD) + $(QUIET) echo -e "CC\t$<.$(ENV)" + $(QUIET) $(CC) -c -o $(BUILD)/$*.c.$(ENV).o $(CFLAGS) -MF $(BUILD)/$*.c.$(ENV).d -MMD $< $(NO_WARNINGS) + +$(BUILD)/%.cc.$(ENV).d $(BUILD)/%.cc.$(ENV).o: $(DIR)/%.cc | $(BUILD) + $(QUIET) echo -e "CXX\t$<.$(ENV)" + $(QUIET) $(CXX) -c -o $(BUILD)/$*.cc.$(ENV).o $(CXXFLAGS) -MF $(BUILD)/$*.cc.$(ENV).d -MMD $< \ + $(shell if [ ! 
-L $(DIR)/$*.cc ]; then echo $(WARNING_CFLAGS); else echo $(NO_WARNINGS); fi) + +$(BUILD)/%.c.$(ENV).d $(BUILD)/%.c.$(ENV).o: $(DIR)/%.c | $(BUILD) + $(QUIET) echo -e "CC\t$<.$(ENV)" + $(QUIET) $(CC) -c -o $(BUILD)/$*.c.$(ENV).o $(CFLAGS) -MF $(BUILD)/$*.c.$(ENV).d -MMD $< \ + $(shell if [ ! -L $(DIR)/$*.c ]; then echo $(WARNING_CFLAGS); else echo $(NO_WARNINGS); fi) + +$(BUILD)/$(ENV).a: $(OBJECTS) | $(BUILD) + $(QUIET) echo -e "AR\t$@" + $(QUIET) ar rcs $@ $^ + +$(foreach f,$(patsubst %.o,%.d,$(OBJECTS)),$(eval include $f)) + +all: $(BUILD)/$(ENV).a diff --git a/enclave/README.md b/enclave/README.md new file mode 100644 index 0000000..b24ca86 --- /dev/null +++ b/enclave/README.md @@ -0,0 +1,159 @@ +# SVR2 Enclave Code + +SVR2 uses C++ as its language for building an in-enclave binary, with the +OpenEnclave (hereafter 'OE') SDK. The binary, `enclave.bin` is then signed via +OE's `oesign`, which doesn't matter to us because we don't trust the signature, +just the unique ID (SGX "mrenclave") of the resulting signed config. However, +the `oesign` process does one important thing: it binds a config (either +`svr2_test.conf` or `svr2.conf` to the resulting object. Once this process +is complete, the resulting `enclave.signed` file is ready to be loaded into a +DCAP-based SGX enclave. + +# Host/enclave communication + +Most (all?) host/enclave communication happens through a single ocall/ecall +combination, defined in `../shared/svr2.edl`: + +- `svr2_input_message`: Enclave receives a message (a serialized + `HostToEnclaveMessage` protobuf) from the host. +- `svr2_output_message`: Enclave sends a message (a serialized + `EnclaveToHostMessage` protobuf) to the host. + +The enclave expects (and enforces, via locking) that the `svr2_input_message` +function is called sequentially. It also guarantees that it will call the +`svr2_output_message` function only during such a call, and sequentially. 
+ +Thus, you might get the following control flow: + +``` +svr2_input_message(HostToEnclaveMessage1) + svr2_output_message(EnclaveToHostMessage1.1) + svr2_output_message(EnclaveToHostMessage1.2) + svr2_output_message(EnclaveToHostMessage1.3) +svr2_input_message(HostToEnclaveMessage2) + svr2_output_message(EnclaveToHostMessage2.1) + svr2_output_message(EnclaveToHostMessage2.2) +svr2_input_message(HostToEnclaveMessage3) +svr2_input_message(HostToEnclaveMessage4) + svr2_output_message(EnclaveToHostMessage4.1) +``` + +In such a flow, we can reason that input message 1 has output messages +1.1, 1.2, and 1.3, etc. Note that each input can have an arbitrary number of +outputs, including zero (e.g. message 3). In other words, the enclave can +be treated as a function: + +``` +func CallEnclave(HostToEnclaveMessage) []EnclaveToHostMessage +``` + +taking in a single HostToEnclaveMessage and returning a list of zero or more +EnclaveToHostMessages. + +Certain messages are 'transactions', or messages with a `Request` that want +a specific `Reply`. It is important to note that if a request is +passed in via a message, the response associated with it may not be part of +the returned list. IE: the host may pass in a transaction request, above, +via `HostToEnclaveMessage1`, but may not get back the reply until +`EnclaveToHostMessage4.1`. Transactions have associated transaction IDs, +which allow for disambiguating requests and their associated responses. +Hosts may send transactions to enclaves and enclaves to hosts. Each direction +maintains a unique keyspace for transaction IDs (so HostToEnclave transaction 1 +and EnclaveToHost transaction 1 are distinct), and each is responsible for +making sure that transaction requests that they pass are uniquely identified. + +## Code Layout + +Code is broken into a set of modules, where each module is a one-level-deep +subdirectory within the top-level `enclave` directory.
Each module is +independently compiled, then all modules are combined in a final linking step +to form the resulting binary. Modules are listed as `SGX_LIBRARIES` and `NSM_LIBRARIES` within +`Makefile`, and must form a DAG of dependencies. Within each list, +higher libraries may depend on lower libraries, but not vice versa. + +Code roughly follows the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). + +# Concurrency in SVR2 Enclave + +With SVR2, we're aiming to utilize a single replica group to serve all traffic. +This, of course, brings up issues around scalability. We can of course add +new replicas to the replica group, but with a strong consensus model relying +on agreement between a quorum (in our case, a simple majority) of voting +replicas, additional replicas have the potential to add load rather than shed +it. + +To handle this, SVR2 is built to, as much as possible, utilize the resources +of non-leader and non-voting replicas. While we're unable to shed or reduce +load on RAM with added members (each replica needs to store the entire +database), we can shed load in the form of CPU and network resources. + +## Utilizing multiple cores + +Even without considering excess replicas, we aim to utilize the resources +of each replica to the fullest extent. To do this, the SVR2 enclave binary +is built as a true multi-threaded process, with targeted locking of code +subsections allowing parallel processing as much as possible. + +One of the most CPU-intensive tasks that SVR2 partakes in is encryption +and decryption. This takes place when replicas communicate with each +other ("peer communication") and when they accept and service connections +from clients ("client communication"). When establishing these secure +connections, the initial handshake is more CPU-intensive, followed +by less intensive block cipher encryption/decryption.
Peer communication +uses long-lived sessions that amortize handshake cost over a long period +of time, while client communication requires a new handshake, a small +amount of communication, and a subsequent closing of the connection. + +For both peer and client communication, we aim to be highly parallel on +a single machine: handshaking and block-cipher encryption/decryption +are done with client- and peer-level locks, rather than global ones. +This approach, though, places some requirements on the host side, as +for both cases, reordering of messages breaks the block-cipher +assumptions of the clients/peers. Internally, SVR2 maintains correct +order of messages it outputs to peers and clients: if message A +to a peer or client happens before message B, then `svr2_output_message(A)` +will be called and allowed to complete before `svr2_output_message(B)`. +However, on the host side, care must be taken to respect this +ordering: when messages are forwarded externally or received from external +hosts, their calls to `svr2_input_message` should follow the same pattern: +if A is received before B in either a peer or client stream, then +`svr2_input_message(A)` should be called and allowed to complete before +`svr2_input_message(B)` is called. + +Some global locks are of course still required, in particular around Raft +and its associated logs/database. However, these locked sections are kept +at a minimum, with as much work done as possible before/after the locks +are acquired. + +## Utilizing multiple machines + +The primary means to scale SVR2 is the addition of replicas. However, +as mentioned, this has the potential to hinder scaling, especially if +the leader alone is allowed to perform CPU-intensive tasks like servicing +client requests. For this reason, SVR2 is built to allow any replica to +service requests from any client.
+ +When a client connects to SVR2 in a non-leader replica, it will perform +the client handshake and receive/decrypt the client's request entirely +on its own. Once it has done so, it will forward the request to the current +leader as an enclave-to-enclave transaction, receiving in response either +a failure or a log location (an `(index, term)` pair) associated with the +write. Failures are immediately returned to the client. A success, though, +creates a watch-point in the non-leader replica's raft log at `index`. +The replica will wait until `index` is a committed part of its own log (via +normal Raft `AppendEntries` mechanisms), then will check the `term` of the +committed log. If that matches the `term` returned from its write request, +by definition the log at `index` contains the client's request, and when +applying that request to its local database, it can safely return the +response to that client over its still-open channel. + +By this mechanism, load (especially client handshake and communication +load) can be shared across all replicas. Crucially, this includes +non-voting replicas, which can be added with minimal increase to the +load on the voting replicas. As non-voting replicas still receive +Raft logs and their commitments, they can happily service client +requests. 
+ +## More topics +- [Raft healing](../docs/Healing.md) +- [Enclave messages](../docs/Messages.md) diff --git a/enclave/SipHash b/enclave/SipHash new file mode 160000 index 0000000..eee7d0d --- /dev/null +++ b/enclave/SipHash @@ -0,0 +1 @@ +Subproject commit eee7d0d84dc7731df2359b243aa5e75d85f6eaef diff --git a/enclave/attestation/attestation.cc b/enclave/attestation/attestation.cc new file mode 100644 index 0000000..2f6a683 --- /dev/null +++ b/enclave/attestation/attestation.cc @@ -0,0 +1,90 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "attestation/attestation.h" + +#include +#include +#include + + +#include + +#include "noise/noise.h" +#include "metrics/metrics.h" +#include "proto/error.pb.h" +#include "util/macros.h" + +namespace svr2::attestation { +const oe_uuid_t sgx_remote_uuid = {OE_FORMAT_UUID_SGX_ECDSA}; + +/** + * Helper function used to make the claim-finding process more convenient. Given + * the claim name, claim list, and its size, returns the claim with that claim + * name in the list. 
+ */ +const oe_claim_t* FindClaim(const oe_claim_t* claims, size_t claims_size, + const char* name) { + for (size_t i = 0; i < claims_size; i++) { + if (strcmp(claims[i].name, name) == 0) return &(claims[i]); + } + return nullptr; +} + +error::Error ReadKeyFromVerifiedClaims(oe_claim_t* claims, size_t claims_length, + std::array& out) { + const oe_claim_t* claim; + oe_claim_t* custom_claims = nullptr; + size_t custom_claims_length = 0; + + // read the custom claims + if ((claim = FindClaim(claims, claims_length, + OE_CLAIM_CUSTOM_CLAIMS_BUFFER)) == nullptr) { + return COUNTED_ERROR(Env_CustomClaimsMissing); + } + + // deserialize custom claims + if (oe_deserialize_custom_claims(claim->value, claim->value_size, + &custom_claims, + &custom_claims_length) != OE_OK) { + return COUNTED_ERROR(Env_CustomClaimsDeserialize); + } + + auto free_custom_claims_known_size = [custom_claims_length](oe_claim_t* ptr) { + return oe_free_custom_claims(ptr, custom_claims_length); + }; + std::unique_ptr + free_custom_claims(custom_claims, free_custom_claims_known_size); + + // There is one custom claim with name "pk". The value is the key we will + // return. 
+ if (strcmp(custom_claims[0].name, "pk") != 0) { + return COUNTED_ERROR(Env_AttestationPubkeyMissing); + } + + if (custom_claims[0].value_size != out.size()) { + return COUNTED_ERROR(Env_AttestationPubkeyInvalidSize); + } + + std::copy(custom_claims[0].value, + custom_claims[0].value + custom_claims[0].value_size, out.begin()); + return error::OK; +} + +std::pair VerifyAndReadClaims( + const std::string& evidence, const std::string& endorsements) { + const uint8_t* evidence_data = + reinterpret_cast(evidence.data()); + const uint8_t* endorsements_data = + reinterpret_cast(endorsements.data()); + oe_claim_t* claims = nullptr; + size_t claims_length = 0; + CHECK(OE_OK == oe_verify_evidence(&sgx_remote_uuid, evidence_data, + evidence.size(), endorsements_data, + endorsements.size(), nullptr, 0, &claims, + &claims_length)); + + return std::make_pair(claims, claims_length); +} + +}; // namespace svr2::attestation diff --git a/enclave/attestation/attestation.h b/enclave/attestation/attestation.h new file mode 100644 index 0000000..9348475 --- /dev/null +++ b/enclave/attestation/attestation.h @@ -0,0 +1,51 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_ATTESTATION_ATTESTATION_H__ +#define __SVR2_ATTESTATION_ATTESTATION_H__ + +#include + +#include +#include + +#include "proto/error.pb.h" + + +namespace svr2::attestation { + +extern const oe_uuid_t sgx_remote_uuid; + +/** + * Helper function used to make the claim-finding process more convenient. Given + * the claim name, claim list, and its size, returns the claim with that claim + * name in the list. + */ +const oe_claim_t* FindClaim(const oe_claim_t* claims, size_t claims_size, + const char* name); +/** + * Deserializes Open Enclave format custom claims then finds, validates, + * and returns the public key claim. 
+ * + * claims serialized OpenEnclave claims + * claims_length number of claims + * out: array where public key will be written + * returns Env_CustomClaimsMissing, Env_CustomClaimsDeserialize, + * Env_AttestationPubkeyMissing, Env_AttestationPubkeyInvalidSize + */ +error::Error ReadKeyFromVerifiedClaims(oe_claim_t* claims, size_t claims_length, + std::array& out); + +/** + * Verifies evidence and endorsements and returns the parsed array + * of claims in Open Enclave format. + * + * The returned pointer must be freed with `oe_free_claims` + */ +std::pair VerifyAndReadClaims( + const std::string& evidence, const std::string& endorsements); + +}; // namespace svr2::attestation + + +#endif // __SVR2_ATTESTATION_ATTESTATION_H__ diff --git a/enclave/client/client.cc b/enclave/client/client.cc new file mode 100644 index 0000000..dc45a96 --- /dev/null +++ b/enclave/client/client.cc @@ -0,0 +1,225 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "client/client.h" + +#include + +#include "env/env.h" +#include "util/log.h" +#include "metrics/metrics.h" + +namespace svr2::client { + +static std::atomic id_gen{1}; + +const NoiseProtocolId client_protocol = { + .prefix_id = NOISE_PREFIX_STANDARD, + .pattern_id = NOISE_PATTERN_NK, + .dh_id = NOISE_DH_CURVE25519, + .cipher_id = NOISE_CIPHER_CHACHAPOLY, + .hash_id = NOISE_HASH_SHA256, + .hybrid_id = 0, +}; + +Client::Client(const std::string& authenticated_id) + : hs_(noise::WrapHandshakeState(nullptr)), + tx_(noise::WrapCipherState(nullptr)), + rx_(noise::WrapCipherState(nullptr)), + id_(id_gen.fetch_add(1)), + authenticated_id_(authenticated_id) { +} + +Client::~Client() { +} + +error::Error Client::Init(const noise::DHState& dhstate, const e2e::Attestation& attestation) { + util::unique_lock lock(mu_); + NoiseHandshakeState* hs; + if (NOISE_ERROR_NONE != noise_handshakestate_new_by_id(&hs, &client_protocol, NOISE_ROLE_RESPONDER)) { + return COUNTED_ERROR(Client_HandshakeState);
+ } + auto hs_wrap = noise::WrapHandshakeState(hs); + if (NOISE_ERROR_NONE != noise_dhstate_copy( + noise_handshakestate_get_local_keypair_dh(hs), + dhstate.get())) { + return COUNTED_ERROR(Client_CopyDHState); + } + if (NOISE_ERROR_NONE != noise_handshakestate_start(hs)) { + return COUNTED_ERROR(Client_HandshakeStart); + } + hs_start_.mutable_test_only_pubkey()->resize(32, '\0'); + if (NOISE_ERROR_NONE != noise_dhstate_get_public_key( + dhstate.get(), + noise::StrU8Ptr(hs_start_.mutable_test_only_pubkey()), + hs_start_.mutable_test_only_pubkey()->size())) { + return COUNTED_ERROR(Client_ExtractPublicKey); + } + *hs_start_.mutable_evidence() = attestation.evidence(); + *hs_start_.mutable_endorsement() = attestation.endorsements(); + hs_.swap(hs_wrap); + return error::OK; +} + +std::pair Client::FinishHandshake(context::Context* ctx, const std::string& data) { + ACQUIRE_LOCK(mu_, ctx, lock_client); + MEASURE_CPU(ctx, cpu_client_hs_finish); + if (!hs_.get() || tx_.get() || rx_.get() + || noise_handshakestate_get_action(hs_.get()) != NOISE_ACTION_READ_MESSAGE) { + return std::make_pair("", COUNTED_ERROR(Client_HandshakeState)); + } + std::string buffer = data; + NoiseBuffer read_buf = noise::BufferInputFromString(&buffer); + if (NOISE_ERROR_NONE != noise_handshakestate_read_message(hs_.get(), &read_buf, nullptr)) { + return std::make_pair("", COUNTED_ERROR(Client_FinishReadHandshake)); + } + if (NOISE_ACTION_WRITE_MESSAGE != noise_handshakestate_get_action(hs_.get())) { + return std::make_pair("", COUNTED_ERROR(Client_HandshakeState)); + } + buffer.resize(noise::HANDSHAKE_INIT_SIZE, '\0'); + NoiseBuffer write_buf = noise::BufferOutputFromString(&buffer); + if (NOISE_ERROR_NONE != noise_handshakestate_write_message(hs_.get(), &write_buf, nullptr)) { + return std::make_pair("", COUNTED_ERROR(Client_FinishWriteHandshake)); + } + buffer.resize(write_buf.size); + if (NOISE_ACTION_SPLIT != noise_handshakestate_get_action(hs_.get())) { + return std::make_pair("", 
COUNTED_ERROR(Client_HandshakeState)); + } + NoiseCipherState* tx; + NoiseCipherState* rx; + if (NOISE_ERROR_NONE != noise_handshakestate_split(hs_.get(), &tx, &rx)) { + return std::make_pair("", COUNTED_ERROR(Client_FinishSplit)); + } + tx_.reset(tx); + rx_.reset(rx); + hs_.reset(nullptr); + return std::make_pair(buffer, error::OK); +} + +error::Error Client::DecryptRequest(context::Context* ctx, const std::string& data, google::protobuf::MessageLite* request) { + ACQUIRE_LOCK(mu_, ctx, lock_client); + MEASURE_CPU(ctx, cpu_client_decrypt); + if (hs_.get() || !tx_.get() || !rx_.get()) { + return COUNTED_ERROR(Client_DecryptState); + } + auto [plaintext, err] = noise::Decrypt(rx_.get(), data); + if (err != error::OK) { + return err; + } + if (!request->ParseFromString(plaintext)) { + return COUNTED_ERROR(Client_DecryptParse); + } + return error::OK; +} + +std::pair Client::EncryptResponse(context::Context* ctx, const google::protobuf::MessageLite& response) { + ACQUIRE_LOCK(mu_, ctx, lock_client); + MEASURE_CPU(ctx, cpu_client_encrypt); + if (hs_.get() || !tx_.get() || !rx_.get()) { + return std::make_pair("", COUNTED_ERROR(Client_EncryptState)); + } + std::string plaintext; + if (!response.SerializeToString(&plaintext)) { + return std::make_pair("", COUNTED_ERROR(Client_EncryptSerialize)); + } + return noise::Encrypt(tx_.get(), plaintext); +} + +std::pair ClientManager::NewClient(context::Context* ctx, const std::string& authenticated_id) { + ACQUIRE_LOCK(mu_, ctx, lock_clientmanager); + MEASURE_CPU(ctx, cpu_client_hs_start); + std::unique_ptr c(new Client(authenticated_id)); + error::Error err = c->Init(dhstate_, attestation_); + if (err != error::OK) { + return std::make_pair(nullptr, err); + } + Client* ptr = c.get(); + clients_[ptr->ID()] = std::move(c); + GAUGE(client, clients)->Set(clients_.size()); + COUNTER(client, created)->Increment(); + return std::make_pair(ptr, error::OK); +} + +Client* ClientManager::GetClient(context::Context* ctx, ClientID id) const 
{ + ACQUIRE_LOCK(mu_, ctx, lock_clientmanager); + auto find = clients_.find(id); + if (find == clients_.end()) { return nullptr; } + return find->second.get(); +} + +bool ClientManager::RemoveClient(context::Context* ctx, ClientID id) { + ACQUIRE_LOCK(mu_, ctx, lock_clientmanager); + auto find = clients_.find(id); + if (find == clients_.end()) { return false; } + clients_.erase(find); + GAUGE(client, clients)->Set(clients_.size()); + COUNTER(client, closed)->Increment(); + return true; +} + +noise::DHState ClientManager::NewDHState() { + COUNTER(client, new_dh_state)->Increment(); + noise::DHState out = noise::WrapDHState(nullptr); + NoiseDHState* dhstate; + if (NOISE_ERROR_NONE != noise_dhstate_new_by_id(&dhstate, client::client_protocol.dh_id)) { + return out; + } + noise::DHState client_dh = noise::WrapDHState(dhstate); + if (NOISE_ERROR_NONE != noise_dhstate_generate_keypair(dhstate)) { + return out; + } + client_dh.swap(out); + return out; +} + +error::Error ClientManager::RotateKeyAndRefreshAttestation(context::Context* ctx, const enclaveconfig::RaftGroupConfig& config) { + auto dhstate = NewDHState(); + auto [attestation, err] = GetAttestation(dhstate, config); + if (err != error::OK) { + COUNTER(client, key_rotate_failure)->Increment(); + return err; + } + ACQUIRE_LOCK(mu_, ctx, lock_clientmanager); + dhstate_.swap(dhstate); + attestation_.CopyFrom(attestation); + COUNTER(client, key_rotate_success)->Increment(); + return error::OK; +} + +error::Error ClientManager::RefreshAttestation(context::Context* ctx, const enclaveconfig::RaftGroupConfig& config) { + auto dhstate = DHState(ctx); + auto [attestation, err] = GetAttestation(DHState(ctx), config); + if (err != error::OK) { + COUNTER(client, attestation_refresh_failure)->Increment(); + return err; + } + ACQUIRE_LOCK(mu_, ctx, lock_clientmanager); + attestation_.CopyFrom(attestation); + // There's a chance that a RotateKeyAndRefreshAttestation call + // could have happened between when we got dhstate and 
when we're + // setting attestation here... reset to the one we received just + // in case. + dhstate_.swap(dhstate); + COUNTER(client, attestation_refresh_success)->Increment(); + return error::OK; +} + +std::pair ClientManager::GetAttestation(const noise::DHState& dhstate, const enclaveconfig::RaftGroupConfig& config) { + e2e::Attestation attestation; + // get attestation for its public key + uint8_t public_key[32]; + if (NOISE_ERROR_NONE != noise_dhstate_get_public_key(dhstate.get(), public_key, sizeof(public_key))) { + return std::make_pair(attestation, error::Peers_NewKeyPublic); + } + + env::PublicKey public_key_array {}; + std::copy(std::begin(public_key), std::end(public_key), std::begin(public_key_array)); + return env::environment->Evidence(public_key_array, config); +} + +noise::DHState ClientManager::DHState(context::Context* ctx) const { + ACQUIRE_LOCK(mu_, ctx, lock_clientmanager); + return noise::CloneDHState(dhstate_); +} + +} // namespace svr2::client diff --git a/enclave/client/client.h b/enclave/client/client.h new file mode 100644 index 0000000..7762cc5 --- /dev/null +++ b/enclave/client/client.h @@ -0,0 +1,86 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_CLIENT_CLIENT_H__ +#define __SVR2_CLIENT_CLIENT_H__ + +#include +#include "proto/error.pb.h" +#include "proto/e2e.pb.h" +#include "noise/noise.h" +#include "sip/hasher.h" +#include "util/endian.h" +#include "util/mutex.h" +#include "context/context.h" + +namespace svr2::client { + +class ClientManager; +typedef uint64_t ClientID; +extern const NoiseProtocolId client_protocol; + +class Client { + public: + ClientID ID() const { return id_; } + // Returns ClientHandshakeStart, with std::move semantics, so this + // function should be used only once. 
+ ClientHandshakeStart MovedHandshakeStart() EXCLUDES(mu_) { + util::unique_lock lock(mu_); + return std::move(hs_start_); + } + + bool Handshaking() const { + util::unique_lock lock(mu_); + return hs_.get() != nullptr; + } + std::pair FinishHandshake(context::Context* ctx, const std::string& data) EXCLUDES(mu_); + + error::Error DecryptRequest(context::Context* ctx, const std::string& data, google::protobuf::MessageLite* request) EXCLUDES(mu_); + std::pair EncryptResponse(context::Context* ctx, const google::protobuf::MessageLite& response) EXCLUDES(mu_); + + const std::string& authenticated_id() const { return authenticated_id_; } + + private: + ~Client(); + explicit Client(const std::string& authenticated_id); + error::Error Init(const noise::DHState& dhstate, const e2e::Attestation& attestation) EXCLUDES(mu_); + friend class ClientManager; + friend std::unique_ptr::deleter_type; + + mutable util::mutex mu_; + ClientHandshakeStart hs_start_ GUARDED_BY(mu_); + noise::HandshakeState hs_ GUARDED_BY(mu_); + noise::CipherState tx_ GUARDED_BY(mu_); + noise::CipherState rx_ GUARDED_BY(mu_); + const size_t id_; + const std::string authenticated_id_; +}; + +class ClientManager { + public: + ClientManager(noise::DHState dhstate) : dhstate_(std::move(dhstate)) {} + error::Error RefreshAttestation(context::Context* ctx, const enclaveconfig::RaftGroupConfig& config) EXCLUDES(mu_); + error::Error RotateKeyAndRefreshAttestation(context::Context* ctx, const enclaveconfig::RaftGroupConfig& config) EXCLUDES(mu_); + static noise::DHState NewDHState(); + + std::pair NewClient(context::Context* ctx, const std::string& authenticated_id) EXCLUDES(mu_); + Client* GetClient(context::Context* ctx, ClientID id) const EXCLUDES(mu_); + // Deallocate and remove a client by its ID. + // Client pointers are owned by the ClientManager and can only be deallocated + // via a call to RemoveClient. 
+ bool RemoveClient(context::Context* ctx, ClientID id) EXCLUDES(mu_); + + private: + noise::DHState DHState(context::Context* ctx) const EXCLUDES(mu_); + static std::pair GetAttestation(const noise::DHState& dhstate, const enclaveconfig::RaftGroupConfig& config); + + mutable util::mutex mu_; + noise::DHState dhstate_ GUARDED_BY(mu_); + e2e::Attestation attestation_ GUARDED_BY(mu_); + std::unordered_map> clients_ GUARDED_BY(mu_); + +}; + +} // namespace svr2::client + +#endif // __SVR2_CLIENT_CLIENT_H__ diff --git a/enclave/context/context.cc b/enclave/context/context.cc new file mode 100644 index 0000000..087a460 --- /dev/null +++ b/enclave/context/context.cc @@ -0,0 +1,47 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "context/context.h" +#include "metrics/metrics.h" +#include "util/cpu.h" + +namespace svr2::context { + +Context::Context() : cpu_current_(nullptr), cpu_top_(nullptr, COUNTER(context, cpu_uncategorized)) { + cpu_top_.SetContext(this); +} + +CPUMeasurement::CPUMeasurement(Context* ctx, metrics::Counter* counter) + : ctx_(nullptr), counter_(counter), ticks_(util::asm_rdtsc()) { + if (ctx != nullptr) { + SetContext(ctx); + } +} + +CPUMeasurement::~CPUMeasurement() { + uint64_t ticks = util::asm_rdtsc(); + counter_->IncrementBy(ticks - ticks_); + if (parent_ != nullptr) { + parent_->ticks_ = ticks; + } + ctx_->cpu_current_ = parent_; +} + +void CPUMeasurement::SetContext(Context* ctx) { + CHECK(ctx_ == nullptr); + ctx_ = ctx; + parent_ = ctx_->cpu_current_; + ctx_->cpu_current_ = this; + if (parent_ != nullptr) { + // If there's a parent CPUMeasurement, increment its ticks-so-far. + // When we're destroyed, we'll push parent_->ticks_ forward so ticks + // during our lifetime are not double-counted. 
+ parent_->counter_->IncrementBy(ticks_ - parent_->ticks_); + } +} + +CPUMeasurement Context::MeasureCPU(metrics::Counter* counter) { + return CPUMeasurement(this, counter); +} + +} // namespace svr2::context diff --git a/enclave/context/context.h b/enclave/context/context.h new file mode 100644 index 0000000..3ab7b59 --- /dev/null +++ b/enclave/context/context.h @@ -0,0 +1,101 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_CONTEXT_CONTEXT_H__ +#define __SVR2_CONTEXT_CONTEXT_H__ + +#include +#include + +#include "util/macros.h" +#include "metrics/metrics.h" +#include "util/mutex.h" + +namespace svr2::context { + +class Context; + +// Class CPUMeasurement allows for counting of CPU ticks used in parts of code. +// On creation, it records the number of CPU ticks, and on destruction it adds +// those ticks to a counter. It's not stand-alone - use Context.MeasureCPU or +// better-yet use the MEASURE_CPU macro. +// +// Usage: +// +// void Foo(ctx) { +// MEASURE_CPU(ctx, cpu_foo); +// ... stuff #1 ... +// Bar(ctx) +// ... stuff #2 ... +// } +// void Bar(ctx) { +// MEASURE_CPU(ctx, cpu_bar); // turns off cpu_foo ticking, then back on when destroyed +// ... stuff #3 ... +// } +// +// This will count CPU ticks of stuff#1 and stuff#2 (but NOT stuff#3) in +// COUNTER(context, cpu_foo), and measure stuff#3 in COUNTER(context, cpu_bar). +class CPUMeasurement { + public: + ~CPUMeasurement(); + private: + friend class Context; + CPUMeasurement(Context* ctx, metrics::Counter* counter); + void SetContext(Context* ctx); + + Context* ctx_; + metrics::Counter* counter_; + CPUMeasurement* parent_; // Provides a singly-linked list back to parent CPUMeasurements. 
+ uint64_t ticks_; +}; + +class Context { + public: + DELETE_COPY_AND_ASSIGN(Context); + Context(); + + // Protobuf creates a protobuf of type that has a lifetime tied + // to the lifetime of this Context (IE: when this context falls out of scope, + // it will be cleaned up) using a protobuf Arena. This optimization allows + // for much faster creation and deletion of intermediate protobufs. However, + // care should be taken to not store the output of this function long-term + // in a class that will live beyond the scope of this Context, as the pointer + // will be invalidated at that time. + template + T* Protobuf() { + return google::protobuf::Arena::CreateMessage(&arena_); + } + + CPUMeasurement MeasureCPU(metrics::Counter* counter); + + // All protobufs returned by Protobuf() are no longer valid after this call. + void GarbageCollectProtobufs() { arena_.Reset(); } + + private: + friend class CPUMeasurement; + google::protobuf::Arena arena_; + CPUMeasurement* cpu_current_; + CPUMeasurement cpu_top_; +}; + +} // namespace svr2::context + +#define MEASURE_CPU(ctx, name) \ + ::svr2::context::CPUMeasurement __cpumeasure_ ## __COUNTER__ = (ctx)->MeasureCPU(COUNTER(context, name)) + +// Creates an RAII util::unique_lock named `lockname`. Use this +// if you need to do things with the lock after you create it (e.g., explicitly +// calling `unlock()`). +#define ACQUIRE_NAMED_LOCK(lockname, mu, ctx, name) \ + util::unique_lock lockname(mu, std::defer_lock); \ + { \ + MEASURE_CPU(ctx, name); \ + lockname.lock(); \ + } +// Creates an RAII util::unique_lock on the given mu with an arbitrary +// name, for when you need `mu` locked but you're not doing anything +// tricky with it like manually unlocking it after. This is more like +// std::lock_guard. 
+#define ACQUIRE_LOCK(mu, ctx, name) ACQUIRE_NAMED_LOCK(__lock_ ## __COUNTER__, mu, ctx, name) + +#endif // __SVR2_CONTEXT_CONTEXT_H__ diff --git a/enclave/context/tests/acquire_lock.cc b/enclave/context/tests/acquire_lock.cc new file mode 100644 index 0000000..526378f --- /dev/null +++ b/enclave/context/tests/acquire_lock.cc @@ -0,0 +1,69 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP context +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +#include +#include "context/context.h" +#include "util/macros.h" +#include "util/mutex.h" +#include +#include +#include + +namespace svr2::util { + +class AcquireLockTest : public ::testing::Test { + public: + util::mutex mu; + int in_use GUARDED_BY(mu) = 0; + + static void* AcquireAndSleep(void* in) { + auto t = (AcquireLockTest*) in; + context::Context ctx; + ACQUIRE_LOCK(t->mu, &ctx, lock_test); + t->in_use++; + for (int i = 0; i < 10; i++) { + usleep(100000); + CHECK(t->in_use == 1); + } + t->in_use--; + CHECK(t->in_use == 0); + return NULL; + } + + static void* AcquireNamedAndSleep(void* in) { + auto t = (AcquireLockTest*) in; + context::Context ctx; + ACQUIRE_NAMED_LOCK(lock, t->mu, &ctx, lock_test); + t->in_use++; + for (int i = 0; i < 10; i++) { + usleep(100000); + CHECK(t->in_use == 1); + } + t->in_use--; + CHECK(t->in_use == 0); + return NULL; + } +}; + +TEST_F(AcquireLockTest, Unnamed) { + pthread_t t1, t2, t3, t4; + auto start = time(NULL); + CHECK(0 == pthread_create(&t1, NULL, &AcquireLockTest::AcquireAndSleep, this)); + CHECK(0 == pthread_create(&t2, NULL, &AcquireLockTest::AcquireNamedAndSleep, this)); + CHECK(0 == pthread_create(&t3, NULL, &AcquireLockTest::AcquireAndSleep, this)); + CHECK(0 == pthread_create(&t4, NULL, &AcquireLockTest::AcquireNamedAndSleep, this)); + CHECK(0 == pthread_join(t1, NULL)); + CHECK(0 == pthread_join(t2, NULL)); + CHECK(0 == pthread_join(t3, NULL)); + CHECK(0 == pthread_join(t4, NULL)); + auto diff = 
time(NULL) - start; + ASSERT_GE(diff, 3); + ASSERT_LE(diff, 5); +} + +} // namespace svr2::util diff --git a/enclave/core/core.cc b/enclave/core/core.cc new file mode 100644 index 0000000..9345227 --- /dev/null +++ b/enclave/core/core.cc @@ -0,0 +1,1689 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "core/core.h" + +#include + +#include "proto/enclaveconfig.pb.h" +#include "util/macros.h" +#include "env/env.h" +#include "sender/sender.h" +#include "context/context.h" +#include "util/log.h" +#include "util/bytes.h" +#include "util/constant.h" +#include "util/endian.h" +#include "core/internal.h" +#include "metrics/metrics.h" +#include "hmac/hmac.h" +#include "util/hex.h" + +#define IDLOG(x) LOG(x) << "(" << ID().DebugString() << ") " + +namespace svr2::core { + +namespace { + +void LogRaftGroupConfig(const std::string& name, const enclaveconfig::RaftGroupConfig& c) { + LOG(INFO) << "RaftGroupConfig(" << name << "):" + << " min_voting_replicas:" << c.min_voting_replicas() + << " max_voting_replicas:" << c.max_voting_replicas() + << " super_majority:" << c.super_majority() + << " attestation_timeout:" << c.attestation_timeout() + << " db_version:" << c.db_version() + << " simulated:" << c.simulated(); +} + +bool RaftGroupConfigsEqualExceptForGroupID(const enclaveconfig::RaftGroupConfig& a, const enclaveconfig::RaftGroupConfig& b) { + LOG(INFO) << "Comparing group configs:"; + LogRaftGroupConfig("a", a); + LogRaftGroupConfig("b", b); + return + a.min_voting_replicas() == b.min_voting_replicas() && + a.max_voting_replicas() == b.max_voting_replicas() && + a.super_majority() == b.super_majority() && + a.db_version() == b.db_version() && + a.attestation_timeout() == b.attestation_timeout() && + a.simulated() == b.simulated(); +} + +error::Error ValidateRaftGroupConfig(const enclaveconfig::RaftGroupConfig& c) { + if (c.min_voting_replicas() > c.max_voting_replicas()) { return 
COUNTED_ERROR(Core_RaftGroupConfigMinReplicasGreaterThanMaxReplicas); } + if (c.min_voting_replicas() < 1) { return COUNTED_ERROR(Core_RaftGroupConfigMinReplicasTooSmall); } + if (c.attestation_timeout() < 1) { return COUNTED_ERROR(Core_RaftGroupConfigAttestationTimeoutTooSmall); } + auto d = db::DB::New(c.db_version()); + if (d.get() == nullptr) { return COUNTED_ERROR(Core_DBVersionInvalid); } + return error::OK; +} + +enclaveconfig::EnclaveConfig DefaultEnclaveConfig() { + enclaveconfig::EnclaveConfig def; + def.set_e2e_txn_timeout_ticks(60); + auto raft = def.mutable_raft(); + raft->set_election_ticks(32); + raft->set_heartbeat_ticks(15); + raft->set_replication_chunk_bytes(1<<20); + raft->set_replica_voting_timeout_ticks(60); + raft->set_replica_membership_timeout_ticks(300); + raft->set_log_max_bytes(1<<30); + raft->set_replication_pipeline(4); + return def; +} + +void ReplyWithError(context::Context* ctx, internal::TransactionID tx, error::Error err) { + EnclaveMessage* out = ctx->Protobuf(); + auto resp = out->mutable_h2e_response(); + resp->set_request_id(tx); + resp->set_status(err); + if (err != error::OK) { + LOG(WARNING) << "Responding to host request " << tx << " with error: " << err; + } + sender::Send(*out); +} + +static bool ContainsMe(const peerid::PeerID& me, const raft::ReplicaGroup& group) { + std::string me_str; + me.ToString(&me_str); + for (int i = 0; i < group.replicas_size(); i++) { + auto replica = group.replicas(i); + if (replica.peer_id() == me_str) return true; + } + return false; +} + +} // namespace + +Core::Core(const enclaveconfig::RaftGroupConfig& group_config) : raft_config_template_(group_config), db_version_(group_config.db_version()), db_protocol_(db::DB::New(group_config.db_version())->P()), e2e_txn_id_(0) { +} + +std::pair, error::Error> Core::Create( + context::Context* ctx, + const enclaveconfig::InitConfig& provided_config) { + LOG(INFO) << "Creating core"; + auto config = DefaultEnclaveConfig(); + 
config.MergeFrom(provided_config.enclave_config()); + error::Error err = error::OK; + LOG(INFO) << "Validating"; + if (error::OK != (err = Core::ValidateConfig(config))) { + LOG(INFO) << "Validation error: " << err; + return std::make_pair(nullptr, err); + } + if (error::OK != (err = ValidateRaftGroupConfig(provided_config.group_config()))) { + LOG(INFO) << "Raft group config validation error: " << err; + return std::make_pair(nullptr, err); + } + LOG(INFO) << "Initializing"; + std::unique_ptr core(new Core(provided_config.group_config())); + if (error::OK != (err = core->Init(ctx, config, provided_config.initial_timestamp_unix_secs()))) { + return std::make_pair(nullptr, err); + } + return std::make_pair(std::move(core), error::OK); +} + +error::Error Core::Init(context::Context* ctx, const enclaveconfig::EnclaveConfig& config, util::UnixSecs initial_timestamp_unix_secs) { + RETURN_IF_ERROR(Core::ValidateConfig(config)); + + // The ClientManager will obtain evidence and endorsements as needed. + LOG(INFO) << "Setting up client DHState"; + auto client_dh = client::ClientManager::NewDHState(); + if (client_dh.get() == nullptr) { + return COUNTED_ERROR(Core_InitClientDHState); + } + + // The PeerManager will create a key pair, set the public key as its ID, and obtain attestation + // evidence and endorsements as needed. 
+ LOG(INFO) << "Setting up peer DHState"; + auto peer_manager = std::make_unique(); + RETURN_IF_ERROR(peer_manager->Init(ctx)); + + LOG(INFO) << "Setting up remaining core"; + { + ACQUIRE_LOCK(config_mu_, ctx, lock_core_config); + enclave_config_ = config; + } + peer_manager_ = std::move(peer_manager); + client_manager_ = std::make_unique(std::move(client_dh)); + clock_.SetLocalTime(initial_timestamp_unix_secs); + peer_manager_->SetPeerAttestationTimestamp(ctx, initial_timestamp_unix_secs, raft_config_template_.attestation_timeout()); + + SendTimestampToAll(ctx); + + return error::OK; +} + +error::Error Core::ValidateConfig(const enclaveconfig::EnclaveConfig& config) { + auto raft = config.raft(); + if (raft.election_ticks() == 0) { return COUNTED_ERROR(Config_ElectionTicks); } + if (raft.heartbeat_ticks() >= raft.election_ticks()) { return COUNTED_ERROR(Config_HeartbeatVsElectionTicks); } + if (raft.replication_chunk_bytes() < (1024)) { return COUNTED_ERROR(Config_ReplicationChunk); } + if (raft.replica_voting_timeout_ticks() <= raft.election_ticks()) { return COUNTED_ERROR(Config_ReplicaTimeout); } + if (raft.replica_membership_timeout_ticks() <= raft.replica_voting_timeout_ticks()) { return COUNTED_ERROR(Config_ReplicaTimeout); } + if (raft.log_max_bytes() < 1024) { return COUNTED_ERROR(Config_LogMaxBytes); } + if (raft.replication_pipeline() <= 0 || raft.replication_pipeline() >= UINT32_MAX) { return COUNTED_ERROR(Config_ReplicationPipeline); } + if (config.e2e_txn_timeout_ticks() < 1) { return COUNTED_ERROR(Config_E2ETransactionTimeout); } + return error::OK; +} + +error::Error Core::ValidateConfigChange(const enclaveconfig::EnclaveConfig& old_config, const enclaveconfig::EnclaveConfig& new_config) { + RETURN_IF_ERROR(ValidateConfig(new_config)); + return error::OK; +} + +enclaveconfig::EnclaveConfig* Core::enclave_config(context::Context* ctx) const { + ACQUIRE_LOCK(config_mu_, ctx, lock_core_config); + auto cfg = ctx->Protobuf(); + 
cfg->MergeFrom(enclave_config_); + return cfg; +} + +error::Error Core::Receive(context::Context* ctx, const UntrustedMessage& msg) { + switch (msg.inner_case()) { + case UntrustedMessage::kH2ERequest: { + MEASURE_CPU(ctx, cpu_core_host_msg); + COUNTER(core, host_requests_received)->Increment(); + return HandleHostToEnclave(ctx, msg.h2e_request()); + } + case UntrustedMessage::kTimerTick: + COUNTER(core, timer_ticks_received)->Increment(); + HandleTimerTick(ctx, msg.timer_tick()); + return error::OK; + case UntrustedMessage::kResetPeer:{ + peerid::PeerID peer_id; + RETURN_IF_ERROR(peer_id.FromString(msg.reset_peer().peer_id())); + return peer_manager_->ResetPeer(ctx, peer_id); + } + case UntrustedMessage::kPeerMessage: { + MEASURE_CPU(ctx, cpu_core_peer_msg); + COUNTER(core, peer_msgs_received)->Increment(); + return HandlePeerMessage(ctx, msg); + } + default: + COUNTER(core, invalid_msgs_received)->Increment(); + return error::General_Unimplemented; + } +} + +error::Error Core::HandleHostToEnclave(context::Context* ctx, const HostToEnclaveRequest& msg) { + internal::TransactionID tx = msg.request_id(); + if (tx == 0) { + return COUNTED_ERROR(Core_HostToEnclaveTransactionID); + } + IDLOG(DEBUG) << "request " << tx << " is " << msg.inner_case(); + switch (msg.inner_case()) { + case HostToEnclaveRequest::kNewClient: { + MEASURE_CPU(ctx, cpu_core_client_msg); + HandleNewClient(ctx, msg.new_client(), tx); + } return error::OK; + case HostToEnclaveRequest::kExistingClient: { + MEASURE_CPU(ctx, cpu_core_client_msg); + error::Error err = HandleExistingClient(ctx, msg.existing_client(), tx); + // We never let client errors get us down, but we do close down clients + // with abandon if they encounter errors. + if (err != error::OK) { + client_manager_->RemoveClient(ctx, msg.existing_client().client_id()); + ReplyWithError(ctx, tx, err); + } + } return error::OK; // return OK, even if we closed the client. 
+ case HostToEnclaveRequest::kCloseClient: { + MEASURE_CPU(ctx, cpu_core_client_msg); + client_manager_->RemoveClient(ctx, msg.close_client().client_id()); + ReplyWithError(ctx, tx, error::OK); + } return error::OK; + case HostToEnclaveRequest::kCreateNewRaftGroup: { + HandleCreateNewRaftGroupRequest(ctx, tx); + } return error::OK; + case HostToEnclaveRequest::kJoinRaft: { + HandleJoinRaft(ctx, msg.join_raft(), tx); + } return error::OK; + case HostToEnclaveRequest::kPingPeer: { + peerid::PeerID peer_id; + error::Error peer_id_err = peer_id.FromString(msg.ping_peer().peer_id()); + if (peer_id_err != error::OK) { + ReplyWithError(ctx, tx, peer_id_err); + return error::OK; + } + auto req = ctx->Protobuf(); + req->set_ping(true); + SendE2ETransaction(ctx, peer_id, *req, true, + [tx](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp) { + ReplyWithError(ctx, tx, err); + }); + } return error::OK; + case HostToEnclaveRequest::kRefreshAttestation: { + error::Error peer_err = peer_manager_->RefreshAttestation(ctx); + error::Error client_err = HandleRefreshAttestation(ctx, msg.refresh_attestation().rotate_client_key()); + ReplyWithError(ctx, tx, peer_err != error::OK ? 
peer_err : client_err); + } return error::OK; + case HostToEnclaveRequest::kRequestVoting: { + ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft); + if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_RaftState)); + } else if (raft_.loaded.raft->voting()) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_VotingRequestedForVotingMember)); + } else if (!raft_.loaded.raft->leader().has_value()) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_LeaderUnknown)); + } else { + auto txn_req = ctx->Protobuf(); + txn_req->set_raft_voting_request(true); + SendE2ETransaction(ctx, *raft_.loaded.raft->leader(), *txn_req, true, + [tx](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp) { + ReplyWithError(ctx, tx, err); + }); + } + } return error::OK; + case HostToEnclaveRequest::kGetEnclaveStatus: { + auto out = ctx->Protobuf(); + auto resp = out->mutable_h2e_response(); + resp->set_request_id(tx); + auto [replica_status, err] = HandleGetEnclaveStatus(ctx); + if (err != error::OK) { + ReplyWithError(ctx, tx, err); + } else { + resp->mutable_get_enclave_status_reply()->MergeFrom(replica_status); + sender::Send(*out); + } + } return error::OK; + case HostToEnclaveRequest::kRequestMetrics: { + env::environment->UpdateEnvStats(); + EnclaveMessage* out = ctx->Protobuf(); + auto resp = out->mutable_h2e_response(); + resp->set_request_id(tx); + *resp->mutable_metrics_reply() = metrics::AllAsPB(); + sender::Send(*out); + } return error::OK; + case HostToEnclaveRequest::kDatabaseRequest: + return HandleHostDatabaseRequest(ctx, tx, msg.database_request()); + case HostToEnclaveRequest::kReconfigure: { + auto err = HandleReconfigure(ctx, tx, msg.reconfigure()); + ReplyWithError(ctx, tx, err); + } return error::OK; + case HostToEnclaveRequest::kSetLogLevel: { + if (msg.set_log_level() >= ::svr2::enclaveconfig::LOG_LEVEL_MAX) { + ReplyWithError(ctx, tx, error::Core_InvalidLogLevel); + } else { + 
util::SetLogLevel(msg.set_log_level()); + ReplyWithError(ctx, tx, error::OK); + } + } return error::OK; + case HostToEnclaveRequest::kRelinquishLeadership: { + HandleRelinquishLeadership(ctx, tx); + } return error::OK; + case HostToEnclaveRequest::kRequestRemoval: { + HandleHostRequestedRaftRemoval(ctx, tx); + } return error::OK; + case HostToEnclaveRequest::kHashes: { + auto err = HandleHostHashes(ctx, tx); + if (err != error::OK) { ReplyWithError(ctx, tx, err); } + } return error::OK; + default: + return error::General_Unimplemented; + } +} + +void Core::HandleNewClient(context::Context* ctx, const NewClientRequest& msg, internal::TransactionID tx) { + auto [client, err] = client_manager_->NewClient(ctx, msg.client_authenticated_id()); + if (err != error::OK) { + ReplyWithError(ctx, tx, err); + COUNTER(core, new_client_failure)->Increment(); + return; + } + auto out = ctx->Protobuf(); + auto resp = out->mutable_h2e_response(); + resp->set_request_id(tx); + auto new_client = resp->mutable_new_client_reply(); + new_client->set_client_id(client->ID()); + *new_client->mutable_handshake_start() = client->MovedHandshakeStart(); + sender::Send(*out); + COUNTER(core, new_client_success)->Increment(); +} + +error::Error Core::HandleExistingClient(context::Context* ctx, const ExistingClientRequest& msg, internal::TransactionID tx) { + client::ClientID client_id = msg.client_id(); + client::Client* c = client_manager_->GetClient(ctx, client_id); + if (c == nullptr) { + return COUNTED_ERROR(Core_ClientNotFound); + } + if (c->Handshaking()) { + auto [handshake, err] = c->FinishHandshake(ctx, msg.data()); + RETURN_IF_ERROR(err); + auto out = ctx->Protobuf(); + auto resp = out->mutable_h2e_response(); + resp->set_request_id(tx); + resp->mutable_existing_client_reply()->set_data(handshake); + sender::Send(*out); + return error::OK; + } + auto request = db_protocol_->RequestPB(ctx); + RETURN_IF_ERROR(c->DecryptRequest(ctx, msg.data(), request)); + auto [log, err] = 
db_protocol_->LogPBFromRequest(ctx, std::move(*request), c->authenticated_id()); + RETURN_IF_ERROR(err); + RETURN_IF_ERROR(db_protocol_->ValidateClientLog(*log)); + std::string serialized; + if (!log->SerializeToString(&serialized)) { + return COUNTED_ERROR(Core_SerializeClientLog); + } + return RaftWriteLogTransaction(ctx, serialized, ClientLogTransaction(ctx, client_id, tx)); +} + +void Core::HandleCreateNewRaftGroupRequest(context::Context* ctx, internal::TransactionID tx) { + LOG(INFO) << "Attempting to create new raft group"; + ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft); + if (raft_.state != svr2::RAFTSTATE_NO_STATE) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_RaftState)); + } + enclaveconfig::RaftGroupConfig cfg = raft_config_template_; + uint8_t group_id_bytes[8]; + error::Error gid_err = env::environment->RandomBytes(group_id_bytes, sizeof(group_id_bytes)); + if (gid_err != error::OK) { + ReplyWithError(ctx, tx, gid_err); + } + raft::GroupId group_id = util::BigEndian64FromBytes(group_id_bytes); + cfg.set_group_id(group_id); + cfg.set_db_version(db_version_); + + raft_.state = svr2::RAFTSTATE_LOADED_PART_OF_GROUP; + enclaveconfig::RaftConfig raft_config = enclave_config(ctx)->raft(); + raft_.loaded = { + .group_config = cfg, + .raft = std::make_unique( + group_id, + peer_manager_->ID(), + raft::Membership::First(peer_manager_->ID()), + std::make_unique(raft_config.log_max_bytes()), + raft_config, + false, + cfg.super_majority()), // committed_log + .db = db::DB::New(db_version_), + .db_last_applied_log = 0, + }; + GAUGE(core, last_index_applied_to_db)->Set(0); + RaftStep(ctx); + ReplyWithError(ctx, tx, error::OK); +} + +void Core::HandleJoinRaft(context::Context* ctx, const JoinRaftRequest& msg, internal::TransactionID tx) { + ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft); + if (raft_.state == svr2::RAFTSTATE_LOADED_PART_OF_GROUP) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_RaftState)); + return; + } + peerid::PeerID peer; + error::Error peer_err = 
peer.FromString(msg.peer_id()); + if (peer_err != error::OK) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_RaftState)); + return; + } + raft_.ClearState(); + raft_.state = svr2::RAFTSTATE_WAITING_FOR_FIRST_CONNECTION; + raft_.waiting_for_first_connection = { + .peer = peer, + .join_tx = tx, + }; + + switch (peer_manager_->PeerState(ctx, peer)) { + case PEER_CONNECTED: + JoinRaftFromFirstPeer(ctx); + break; + case PEER_CONNECTING: + break; + default: + peer_manager_->ConnectToPeer(ctx, peer); + break; + } +} + +void Core::JoinRaftFromFirstPeer(context::Context* ctx) { + CHECK(raft_.state == svr2::RAFTSTATE_WAITING_FOR_FIRST_CONNECTION); + internal::TransactionID tx = raft_.waiting_for_first_connection.join_tx; + peerid::PeerID peer = raft_.waiting_for_first_connection.peer; + IDLOG(VERBOSE) << "requesting to join raft from peer " << peer << " tx=" << tx; + auto req = ctx->Protobuf(); + req->set_get_raft(true); + SendE2ETransaction( + ctx, peer, *req, true, + [this, tx, peer](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp){ + if (err != error::OK) { + ReplyWithError(ctx, tx, err); + return; + } + ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft); + + // We cleared the RaftState before sending this request and we will only proceed with this + // callback if no intermediate action has changed the state. + if (raft_.state != svr2::RAFTSTATE_WAITING_FOR_FIRST_CONNECTION) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_RaftState)); + return; + } + // Since the state is NO_STATE we are guaranteed that the raft_ has default values (no + // actions change raft_ with out changing raft_state). 
+ + auto got = resp->get_raft(); + enclaveconfig::RaftGroupConfig group_config_equality_check = got.group_config(); + group_config_equality_check.clear_group_id(); + if (!RaftGroupConfigsEqualExceptForGroupID(raft_config_template_, group_config_equality_check)) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_GroupConfigMismath)); + return; + } + if (got.replica_group().replicas_size() == 0) { + ReplyWithError(ctx, tx, COUNTED_ERROR(Core_ReceivedEmptyReplicaGroup)); + return; + } + auto [mem, mem_err] = raft::Membership::FromProto(got.replica_group()); + if (mem_err != error::OK) { + ReplyWithError(ctx, tx, mem_err); + return; + } + for (int i = 0; i < got.replica_group().replicas_size(); i++) { + peerid::PeerID p; + const auto& replica = got.replica_group().replicas(i); + err = p.FromString(replica.peer_id()); + if (err != error::OK) { + ReplyWithError(ctx, tx, err); + return; + } + err = peer_manager_->MaybeConnectToPeer(ctx, p); + if (err != error::OK) { + ReplyWithError(ctx, tx, err); + return; + } + } + + LOG(INFO) << "received raft information, switching to loading state and starting replication"; + raft_.ClearState(); + raft_.state = svr2::RAFTSTATE_LOADING; + raft_.loading = { + .group_config = got.group_config(), + .replica_group = got.replica_group(), + .log = std::make_unique(enclave_config(ctx)->raft().log_max_bytes()), + .db = db::DB::New(db_version_), + .mem = std::move(mem), + .load_from = peer, + .join_tx = tx, + }; + + // Reset client attestation based on new group config. 
+ if (error::OK != (err = client_manager_->RefreshAttestation(ctx, raft_.loading.group_config))) { + ReplyWithError(ctx, tx, err); + return; + } + + RequestRaftReplication(ctx); + }); +} + +void Core::RequestRaftReplication(context::Context* ctx) { + if (raft_.state != svr2::RAFTSTATE_LOADING) { + IDLOG(WARNING) << "RequestRaftReplication called while state is " << raft_.state; + return; + } + if (!raft_.loading.started) { + size_t connected = 0; + const auto& voting_replicas = raft_.loading.mem->voting_replicas(); + for (auto peer : peer_manager_->ConnectedPeers(ctx)) { + if (voting_replicas.count(peer)) { + connected++; + } + } + size_t quorum = raft::Raft::quorum_size( + voting_replicas.size(), raft_.loading.group_config.super_majority()); + if (connected < quorum) { + IDLOG(VERBOSE) << "Still waiting for peer connections before starting load, have " << connected << ", need " << quorum; + return; + } + raft_.loading.started = true; + } + uint8_t repl_id[8]; + env::environment->RandomBytes(repl_id, sizeof(repl_id)); + raft_.loading.replication_id = util::BigEndian64FromBytes(repl_id); + internal::TransactionID tx = raft_.loading.join_tx; + const peerid::PeerID& from = raft_.loading.load_from; + + auto req = ctx->Protobuf(); + auto repl = req->mutable_replicate_state(); + repl->set_group_id(raft_.loading.group_config.group_id()); + repl->set_replication_id(raft_.loading.replication_id); + + IDLOG(VERBOSE) << "requesting replication from " << from; + SendE2ETransaction(ctx, from, *req, false /* no timeout */, + [this, from, tx](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp) { + if (err != error::OK) { + // We've failed to replicate state. For now, revert back to no state. 
+ LOG(ERROR) << "failed to replicate state from " << from << ": " << err; + ReplyWithError(ctx, tx, err); + return; + } + IDLOG(INFO) << "finished replicating database, fully loaded"; + ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft); + PromoteRaftToLoaded(ctx); + }); +} + +void Core::PromoteRaftToLoaded(context::Context* ctx) { + internal::Loading loading = std::move(raft_.loading); + raft_.ClearState(); + raft_.state = svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP; + raft::LogIdx db_last_applied_log = loading.log->last_idx(); + raft_.loaded = { + .group_config = loading.group_config, + .raft = std::make_unique( + loading.group_config.group_id(), + peer_manager_->ID(), + std::move(loading.mem), + std::move(loading.log), + enclave_config(ctx)->raft(), + true, + loading.group_config.super_majority()), // committed_log + .db = std::move(loading.db), + .db_last_applied_log = db_last_applied_log, + }; + GAUGE(core, last_index_applied_to_db)->Set(db_last_applied_log); + RaftRequestMembership(ctx, loading.join_tx); +} + +void Core::RaftRequestMembership(context::Context* ctx, internal::TransactionID tx) { + // Never request membership unless in the correct state. + CHECK(raft_.state == svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP); + // We could be tricky and try to find out who the leader is. Instead, we'll + // just send our request to every member. Note that this will cause error + // Raft_AppendEntryNotLeader (5004) to appear in the logs + auto req = ctx->Protobuf(); + req->set_raft_membership_request(true); + // Set a timeout for if we fail to do this. 
// (Continuation of Core::RaftRequestMembership.)
  // When the timeout fires and we are still requesting membership, the whole
  // request is retried.
  timeout::Cancel cancel = timeout_.SetTimeout(ctx, enclave_config(ctx)->e2e_txn_timeout_ticks(), [this, tx](context::Context* ctx) {
    ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
    if (raft_.state == svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP) {
      RaftRequestMembership(ctx, tx);
    }
  });

  for (const auto& peer : raft_.loaded.raft->peers()) {
    IDLOG(VERBOSE) << "requesting raft membership from " << peer;
    SendE2ETransaction(ctx, peer, *req, true,
        [this, tx, cancel](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp) {
          if (err != error::OK) {
            // A failed request to one peer is only logged; `tx` is not answered here.
            LOG(WARNING) << "Error requesting raft membership: " << err;
            return;
          }
          // Watch the log location returned by the peer; `tx` is answered when
          // that log transaction resolves.
          AddLogTransaction(ctx, resp->raft_membership_response(), [this, tx, cancel](
              context::Context* ctx,
              error::Error err,
              const raft::LogEntry* entry,
              const db::DB::Response* response) {
            // HandleRaftMembershipChange does the actual state changes, this
            // just tells our requester that we've succeeded.
            if (err == error::OK) {
              timeout_.CancelTimeout(ctx, cancel);
            }
            ReplyWithError(ctx, tx, err);
          });
        });
  }
}

// Refreshes client attestation using the group config of the current Raft
// state, optionally rotating the key first (`rotate_key`).
// The config is copied out under raft_.mu and the lock is released before
// client_manager_ is called. Returns
// Core_RefreshClientAttestationWithoutRaftConfig when the state is neither
// loading nor loaded.
error::Error Core::HandleRefreshAttestation(context::Context* ctx, bool rotate_key) {
  enclaveconfig::RaftGroupConfig config;
  {  // Copy current config out of Raft state.
    ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
    switch (raft_.state) {
      case svr2::RAFTSTATE_LOADING:
        config.MergeFrom(raft_.loading.group_config);
        break;
      case svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP:
      case svr2::RAFTSTATE_LOADED_PART_OF_GROUP:
        config.MergeFrom(raft_.loaded.group_config);
        break;
      default:
        return COUNTED_ERROR(Core_RefreshClientAttestationWithoutRaftConfig);
    }
  }
  return rotate_key
      ? client_manager_->RotateKeyAndRefreshAttestation(ctx, config)
      : client_manager_->RefreshAttestation(ctx, config);
}

// Builds a status report for this replica: the Raft state plus one entry per
// known peer (including ourselves) with its Raft membership, voting and leader
// flags and its connection status. When this replica is the leader, each other
// peer's entry also gets its follower replication status.
// Returns the report together with error::OK, or a partially filled report and
// the error from FollowerReplicationStatus.
std::pair Core::HandleGetEnclaveStatus(context::Context* ctx) const {
  EnclaveReplicaStatus result;
  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
  result.set_raft_state(raft_.state);

  auto peers = peer_manager_->AllPeers(ctx);
  peers.insert(ID());
  // These stay default/empty unless we are a full member of a group.
  peerid::PeerID leader;
  std::set all_replicas;
  std::set voting_replicas;
  if(raft_.state == svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
    leader = raft_.loaded.raft->leader().value_or(peerid::PeerID());
    auto& membership = raft_.loaded.raft->membership();
    all_replicas = membership.all_replicas();
    voting_replicas = membership.voting_replicas();
  }
  for (auto peer_id : peers) {
    auto peer_status = result.add_peers();
    peer_status->set_peer_id(peer_id.AsString());
    peer_status->set_in_raft(all_replicas.count(peer_id));
    peer_status->set_is_voting(voting_replicas.count(peer_id) > 0);
    peer_status->set_is_leader(peer_id == leader);
    peer_manager_->PeerStatus(ctx, peer_id, peer_status->mutable_connection_status());
    peer_status->set_me(peer_id == ID());

    // Only the leader reports replication progress, and only for other peers.
    if (leader == ID() && peer_id != ID()) {
      auto err = raft_.loaded.raft->FollowerReplicationStatus(peer_id, peer_status->mutable_replication_status());
      if(err != error::OK) {
        return std::make_pair(result, err);
      }
    }
  }
  return std::make_pair(result, error::OK);
}

// Handles a database request issued by the host: parses the serialized request,
// converts it to a log entry under the authenticated ID supplied in `req`,
// and submits it via RaftWriteLogTransaction. The transaction callback counts
// the outcome and replies to `tx`. Parse/convert/serialize failures are
// returned to the caller without replying here.
error::Error Core::HandleHostDatabaseRequest(context::Context* ctx, internal::TransactionID tx, const DatabaseRequest& req) {
  auto cli_req = db_protocol_->RequestPB(ctx);
  if (!cli_req->ParseFromString(req.request())) {
    return COUNTED_ERROR(Core_DeserializeHostDatabaseRequest);
  }
  auto [log, err] = db_protocol_->LogPBFromRequest(ctx, std::move(*cli_req), req.authenticated_id());
  RETURN_IF_ERROR(err);
  std::string serialized;
  if (!log->SerializeToString(&serialized)) {
    return COUNTED_ERROR(Core_SerializeClientLog);
  }
  return RaftWriteLogTransaction(ctx, serialized, [tx](
      context::Context* ctx,
      error::Error err,
      const raft::LogEntry* entry,
      const db::DB::Response* resp) {
    if (err == error::OK) {
      COUNTER(core, host_delete_success)->Increment();
    } else {
      COUNTER(core, host_delete_failure)->Increment();
    }
    ReplyWithError(ctx, tx, err);
  });
}

// Applies a new enclave configuration: `req` is merged over the defaults,
// validated against the current config and stored under config_mu_; then, if a
// Raft instance is loaded, its Raft settings are reconfigured under raft_.mu.
// Returns the validation error, or error::OK. `tx` is not used here.
error::Error Core::HandleReconfigure(context::Context* ctx, internal::TransactionID tx, const enclaveconfig::EnclaveConfig& req) {
  auto new_config = DefaultEnclaveConfig();
  new_config.MergeFrom(req);
  {
    ACQUIRE_LOCK(config_mu_, ctx, lock_core_config);
    RETURN_IF_ERROR(ValidateConfigChange(enclave_config_, new_config));
    enclave_config_ = new_config;
  }
  {
    ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
    if (raft_.state == svr2::RAFTSTATE_LOADED_PART_OF_GROUP
        || raft_.state == svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP) {
      raft_.loaded.raft->Reconfigure(new_config.raft());
    }
  }
  return error::OK;
}

// Asks the local Raft instance to give up leadership and answers `tx` once the
// watched log location resolves. If this replica is not a leader of a loaded
// group, replies OK immediately.
void Core::HandleRelinquishLeadership(context::Context* ctx, internal::TransactionID tx) {
  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
  if (raft_.state != RAFTSTATE_LOADED_PART_OF_GROUP || !raft_.loaded.raft->is_leader()) {
    // We're already not the leader.
    ReplyWithError(ctx, tx, error::OK);
    return;
  }
  raft_.loaded.raft->RelinquishLeadership(ctx);
  // If we succeed in relinquishing leadership, then the log that's one past the
  // last one we have will have a term one greater than the most recent term.
  // Set up a watcher for that.
// (Continuation of Core::HandleRelinquishLeadership: watch the next log index
// at the next term and answer `tx` when that log transaction resolves.)
  raft::LogLocation loc;
  loc.set_idx(raft_.loaded.raft->log().next_idx());
  loc.set_term(raft_.loaded.raft->log().last_term() + 1);
  AddLogTransaction(ctx, loc, [tx](
      context::Context* ctx,
      error::Error err,
      const raft::LogEntry* entry,
      const db::DB::Response* resp) {
    ReplyWithError(ctx, tx, err);
  });
  RaftStep(ctx);
}

// Handles a host request to remove this replica from the Raft group by
// forwarding a removal request to the current leader; `tx` is answered with the
// result of that e2e transaction. Replies with an error instead when this
// replica is not a full group member, is itself the leader, or does not know
// the leader.
void Core::HandleHostRequestedRaftRemoval(context::Context* ctx, internal::TransactionID tx) {
  LOG(VERBOSE) << "HandleHostRequestedRaftRemoval";
  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
  if (raft_.state != RAFTSTATE_LOADED_PART_OF_GROUP) {
    ReplyWithError(ctx, tx, COUNTED_ERROR(Core_RaftState));
  } else if (raft_.loaded.raft->is_leader()) {
    ReplyWithError(ctx, tx, COUNTED_ERROR(Core_LeaderRemovingSelf));
  } else if (!raft_.loaded.raft->leader().has_value()) {
    ReplyWithError(ctx, tx, COUNTED_ERROR(Core_LeaderUnknown));
  } else {
    auto req = ctx->Protobuf();
    req->set_raft_removal_request(true);
    auto peer = *raft_.loaded.raft->leader();
    SendE2ETransaction(ctx, peer, *req, true, [peer, tx](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp) {
      LOG(INFO) << "RaftRemovalRequest to " << peer << ": " << err;
      ReplyWithError(ctx, tx, err);
    });
  }
}

// Sends the host the database hash together with the index and hash chain of
// the last log entry applied to the database.
// Returns Core_RaftState when no Raft instance is loaded, and
// Core_LogNotFoundAtCommitIndex when the log no longer holds the entry at the
// last applied index; in both cases nothing is sent.
error::Error Core::HandleHostHashes(context::Context* ctx, internal::TransactionID tx) {
  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
  if (raft_.state != RAFTSTATE_LOADED_PART_OF_GROUP &&
      raft_.state != RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP) {
    return COUNTED_ERROR(Core_RaftState);
  }
  auto db_hash = raft_.loaded.db->Hash(ctx);
  auto out = ctx->Protobuf();
  auto resp = out->mutable_h2e_response();
  resp->set_request_id(tx);
  auto hashes = resp->mutable_hashes();
  // The hash field is sized to 32 bytes before the hash is copied into it.
  // NOTE(review): assumes db_hash is at most 32 bytes -- confirm against DB::Hash.
  hashes->mutable_db_hash()->resize(32, ' ');
  std::copy(db_hash.begin(), db_hash.end(), hashes->mutable_db_hash()->data());
  hashes->set_commit_idx(raft_.loaded.db_last_applied_log);
  auto log = raft_.loaded.raft->log().At(raft_.loaded.db_last_applied_log).Entry();
  if (log == nullptr) {
    return COUNTED_ERROR(Core_LogNotFoundAtCommitIndex);
  }
  hashes->set_commit_hash_chain(log->hash_chain());
  sender::Send(*out);
  return error::OK;
}

// Handles a timer tick from the host: updates the local clock and its gauge,
// recomputes the group time, advances pending timeouts, and, when this replica
// is a full group member, reconnects to Raft members, ticks Raft, and (as
// leader) proposes the next replica group change if there is one.
void Core::HandleTimerTick(context::Context* ctx, const TimerTick& tick) {
  MEASURE_CPU(ctx, cpu_core_timer_tick);
  auto time = tick.new_timestamp_unix_secs();
  clock_.SetLocalTime(time);
  GAUGE(core, current_local_time)->Set(time);
  // These two run before raft_.mu is taken below; MaybeUpdateGroupTime takes
  // that lock itself.
  MaybeUpdateGroupTime(ctx);
  timeout_.TimerTick(ctx);
  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
  if (raft_.state == svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
    ConnectToRaftMembers(ctx);
    raft_.loaded.raft->TimerTick(ctx);
    if (raft_.loaded.raft->is_leader()) {
      raft::ReplicaGroup* next = NextReplicaGroup(ctx);
      if (next != nullptr) {
        auto [loc, err] = raft_.loaded.raft->ReplicaGroupChange(ctx, *next);
        // We expect errors to occur here, in cases where for example an existing
        // replica group change is already in progress, etc.
// (Continuation of Core::HandleTimerTick: the result of the replica group
// change attempt is only logged, then Raft is stepped.)
        LOG(INFO) << "attempt to change replica group returned " << err;
      }
    }
    RaftStep(ctx);
  }
}

// Recomputes the group clock and publishes it.
// The time is computed over the currently connected peers, or over the voting
// replicas of the Raft membership when a membership is available (loading or
// loaded). The result is written to the groupclock gauge and handed to the peer
// manager as the attestation timestamp.
void Core::MaybeUpdateGroupTime(context::Context* ctx) {
  std::set peers = peer_manager_->ConnectedPeers(ctx);
  {
    // raft_.mu is held only while the peer set is selected.
    ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
    switch (raft_.state) {
      case RAFTSTATE_LOADED_PART_OF_GROUP:
      case RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP:
        peers = raft_.loaded.raft->membership().voting_replicas();
        break;
      case RAFTSTATE_LOADING:
        peers = raft_.loading.mem->voting_replicas();
        break;
      default:
        break;
    }
  }
  auto ts = clock_.GetTime(ctx, peers);
  GAUGE(core, current_groupclock_time)->Set(ts);
  peer_manager_->SetPeerAttestationTimestamp(ctx, ts, raft_config_template_.attestation_timeout());
}

// Requests connections to Raft members whose ID is greater than ours; members
// with an ID less than or equal to ours are skipped. Failures are only logged.
// Reads raft_.loaded without taking a lock; the caller visible in this file
// (HandleTimerTick) holds raft_.mu.
void Core::ConnectToRaftMembers(context::Context* ctx) {
  const auto& membership = raft_.loaded.raft->membership();
  for (auto peer : membership.all_replicas()) {
    if (peer == ID() || peer < ID()) {
      continue;
    }
    auto err = peer_manager_->MaybeConnectToPeer(ctx, peer);
    if (err != error::OK) {
      IDLOG(INFO) << "Requesting connection to detected disconnected peer " << peer << " failed: " << err;
    }
  }
}

// Computes the next single-step membership change to propose, or nullptr when
// there is none (or when we are not a full group member / no leader is known).
// Starting from the current membership, the first applicable rule wins:
//   1. promote the most recently seen non-voting replica to voting,
//   2. demote a voting replica that has not been seen for too long,
//   3. remove a non-voting replica that has not been seen for too long.
// The returned object is allocated through ctx->Protobuf().
raft::ReplicaGroup* Core::NextReplicaGroup(context::Context* ctx) {
  if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) { return nullptr; }
  raft::Raft* r = raft_.loaded.raft.get();
  if (!r->leader()) { return nullptr; }
  // See if we can add a voting member to increase our total.
  const raft::Membership& m = r->membership();
  auto out = ctx->Protobuf();
  *out = m.AsProto();
  // Look for an existing replica to promote to voting.
  if (m.voting_replicas().size() < raft_.loaded.group_config.max_voting_replicas() &&
      m.all_replicas().size() > m.voting_replicas().size()) {
    // Pick the non-voting replica with the smallest last-seen tick count.
    std::string peer_id = "";
    util::Ticks min = util::InvalidTicks;
    for (const auto& peer : m.all_replicas()) {
      util::Ticks last_seen = r->last_seen_ticks(peer);
      if (last_seen < min && m.voting_replicas().count(peer) == 0) {
        peer_id = peer.AsString();
        min = last_seen;
      }
    }
    if (peer_id != "" && min < r->config().election_ticks()) {
      // We've found a peer that's non-voting and that's responded within the last
      // election timeout.  Promote them.
      for (int i = 0; i < out->replicas_size(); i++) {
        if (out->replicas(i).peer_id() == peer_id) {
          out->mutable_replicas(i)->set_voting(true);
          return out;
        }
      }
    }
  }
  // Look for an existing voting replica to demote.
  if (m.voting_replicas().size() > raft_.loaded.group_config.min_voting_replicas()) {
    for (const auto& peer : m.voting_replicas()) {
      util::Ticks last_seen = r->last_seen_ticks(peer);
      if (last_seen != util::InvalidTicks && last_seen > r->config().replica_voting_timeout_ticks()) {
        std::string peer_id = peer.AsString();
        for (int i = 0; i < out->replicas_size(); i++) {
          if (out->replicas(i).peer_id() == peer_id) {
            out->mutable_replicas(i)->set_voting(false);
            return out;
          }
        }
      }
    }
  }
  // Look for non-voting replicas to remove.
  if (m.all_replicas().size() > m.voting_replicas().size()) {
    for (const auto& peer : m.all_replicas()) {
      if (m.voting_replicas().count(peer)) { continue; }
      util::Ticks last_seen = r->last_seen_ticks(peer);
      if (last_seen != util::InvalidTicks && last_seen > r->config().replica_membership_timeout_ticks()) {
        const std::string peer_id = peer.AsString();
        auto it = std::find_if(out->replicas().begin(), out->replicas().end(), [&peer_id](auto& replica) {
          return replica.peer_id() == peer_id;
        });
        if (it != out->replicas().end()) {
          out->mutable_replicas()->erase(it);
          return out;
        }
      }
    }
  }
  return nullptr;
}

// Handles a message from a peer that arrived via the host: lets the peer
// manager receive/decode it and, if that yields an enclave-to-enclave message,
// dispatches it to HandleE2E. Returns the receive error, OK when nothing was
// decoded, or the result of HandleE2E.
error::Error Core::HandlePeerMessage(context::Context* ctx, const UntrustedMessage& msg) {
  const auto remote_msg = msg.peer_message();
  // Parsing the peer ID should always succeed, because the peer manager already did it once.
  peerid::PeerID from;
  CHECK(error::OK == from.FromString(remote_msg.peer_id()));
  // If these are created, they will be so within the arena, so they'll
  // be cleaned up when it falls out of scope.
// (Continuation of Core::HandlePeerMessage.)
  e2e::EnclaveToEnclaveMessage* decoded = nullptr;
  auto err = peer_manager_->RecvFromPeer(ctx, remote_msg, &decoded);
  if (err != error::OK) {
    LOG(WARNING) << "Failed to receive message from " << from << " of type " << remote_msg.inner_case() << ": " << err;
    return err;
  }
  // RecvFromPeer may succeed without producing a message for us to handle.
  if (decoded == nullptr) {
    return error::OK;
  }
  return HandleE2E(ctx, from, *decoded);
}

// Dispatches a decoded enclave-to-enclave message from peer `from`:
//  - kConnected: connection-established handling,
//  - kRaftMessage: fed to the loaded Raft instance (Core_RaftState otherwise),
//  - kTransactionRequest: handled by HandleE2ETransaction,
//  - kTransactionResponse: completes the matching outstanding transaction.
// Any other message type yields General_Unimplemented.
error::Error Core::HandleE2E(context::Context* ctx, const peerid::PeerID& from, const e2e::EnclaveToEnclaveMessage& msg) {
  switch (msg.inner_case()) {
    case e2e::EnclaveToEnclaveMessage::kConnected:
      HandlePeerConnect(ctx, from);
      return error::OK;
    case e2e::EnclaveToEnclaveMessage::kRaftMessage: {
      ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
      MEASURE_CPU(ctx, cpu_core_raft_msg);
      if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP &&
          raft_.state != svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP) {
        return COUNTED_ERROR(Core_RaftState);
      }
      raft_.loaded.raft->Receive(ctx, msg.raft_message(), from);
      RaftStep(ctx);
    } return error::OK;
    case e2e::EnclaveToEnclaveMessage::kTransactionRequest: {
      MEASURE_CPU(ctx, cpu_core_e2e_txn_req);
      return Core::HandleE2ETransaction(ctx, from, msg.transaction_request());
    }
    case e2e::EnclaveToEnclaveMessage::kTransactionResponse: {
      MEASURE_CPU(ctx, cpu_core_e2e_txn_resp);
      const auto& txn_resp = msg.transaction_response();
      ACQUIRE_NAMED_LOCK(lock, e2e_txn_mu_, ctx, lock_core_e2e_txns);
      auto f = outstanding_e2e_transactions_.find(txn_resp.request_id());
      if (f == outstanding_e2e_transactions_.end()) {
        LOG(VERBOSE) << "received response to e2e transaction that has no callback " << txn_resp.request_id();
        return error::OK;
      }
      // Take the callback out of the map and cancel its timeout, then release
      // the lock explicitly before the callback is invoked.
      auto callback = std::move(f->second);
      IDLOG(VERBOSE) << "received response to e2e transaction " << f->first << ": error=" << msg.transaction_response().status();
      outstanding_e2e_transactions_.erase(f);
      timeout_.CancelTimeout(ctx, callback.timeout_cancel);
      lock.unlock();
      callback.callback(ctx, msg.transaction_response().status(), &msg.transaction_response());
    } return error::OK;
    default:
      return error::General_Unimplemented;
  }
}

// Called when a connection to peer `from` has been established. Sends the peer
// our local timestamp and then, depending on the Raft state, tries to start
// replication, starts the join handshake if this is the peer we were waiting
// for, or resets the peer in the loaded Raft instance.
void Core::HandlePeerConnect(context::Context* ctx, const peerid::PeerID& from) {
  IDLOG(INFO) << "successfully established connection to " << from;

  // On each connect, immediately send our most current (local) timestamp.
  SendTimestamp(ctx, from, clock_.GetLocalTime());

  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
  switch (raft_.state) {
    case svr2::RAFTSTATE_LOADING:
      if (!raft_.loading.started) {
        // If we don't have an in-flight request to load stuff and we've connected
        // to a new peer, see if the connection to that new peer is enough to get
        // us started.
        RequestRaftReplication(ctx);
      }
      break;
    case svr2::RAFTSTATE_WAITING_FOR_FIRST_CONNECTION:
      if (from == raft_.waiting_for_first_connection.peer) {
        JoinRaftFromFirstPeer(ctx);
      }
      break;
    case svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP:
    case svr2::RAFTSTATE_LOADED_PART_OF_GROUP:
      raft_.loaded.raft->ResetPeer(ctx, from);
      break;
    default:
      break;
  }
}

// Handles an e2e transaction request from peer `from` and sends the response.
// Most request types are handled synchronously: the handler fills `txn_resp`
// and/or returns an error. If an error occurred, or no response payload was
// set, the reply is sent with SendE2EError (carrying `err`, which may be OK);
// otherwise the filled response is sent to the peer. ReplicateState and
// RaftRemovalRequest return early because their responses are sent
// asynchronously.
error::Error Core::HandleE2ETransaction(context::Context* ctx, const peerid::PeerID& from, const e2e::TransactionRequest& msg) {
  auto e2e_resp = ctx->Protobuf();
  auto txn_resp = e2e_resp->mutable_transaction_response();
  txn_resp->set_request_id(msg.request_id());
  error::Error err = error::OK;
  switch (msg.inner_case()) {
    case e2e::TransactionRequest::kPing:
      txn_resp->set_status(error::OK);
      break;
    case e2e::TransactionRequest::kGetRaft: {
      LOG(VERBOSE) << "GetRaft";
      auto out = txn_resp->mutable_get_raft();
      ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
      if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
        err = COUNTED_ERROR(Core_RaftState);
        break;
      }
      *out->mutable_group_config() = raft_.loaded.group_config;
      *out->mutable_replica_group() = raft_.loaded.raft->membership().AsProto();
    } break;
    case e2e::TransactionRequest::kReplicateState:
      // The response to ReplicateStateRequest will be sent async, not within this transaction call.
      return HandleReplicateStateRequest(ctx, from, msg);
    case e2e::TransactionRequest::kReplicateStatePush: {
      err = HandleReplicateStatePush(ctx, msg.replicate_state_push());
    } break;
    case e2e::TransactionRequest::kRaftMembershipRequest: {
      err = HandleRequestRaftMembership(ctx, from, txn_resp);
    } break;
    case e2e::TransactionRequest::kRaftVotingRequest: {
      err = HandleRequestRaftVoting(ctx, from, txn_resp);
    } break;
    case e2e::TransactionRequest::kRaftWrite: {
      err = HandleRaftWrite(ctx, msg.raft_write(), txn_resp);
    } break;
    case e2e::TransactionRequest::kNewTimestampUnixSecs: {
      HandleNewTimestamp(ctx, from, msg.new_timestamp_unix_secs());
    } break;
    case e2e::TransactionRequest::kRaftRemovalRequest:
      // The response to RaftRemovalRequest will be sent async.
      return HandlePeerRequestedRaftRemoval(ctx, from, msg.request_id());
    default:
      LOG(WARNING) << "unknown e2e transaction type " << msg.inner_case();
      err = error::General_Unimplemented;
      break;
  }
  if (err != error::OK || txn_resp->inner_case() == e2e::TransactionResponse::INNER_NOT_SET) {
    return SendE2EError(ctx, from, msg.request_id(), err);
  }
  return peer_manager_->SendToPeer(ctx, from, *e2e_resp);
}

// Starts pushing our replicated state to `target`, which asked for it in `req`.
// Refused (via SendE2EError) unless we are a full group member and the
// requested group ID matches ours. On success the final response to `req` is
// sent later, from the push pipeline.
error::Error Core::HandleReplicateStateRequest(context::Context* ctx, const peerid::PeerID& target, const e2e::TransactionRequest& req) {
  const e2e::ReplicateStateRequest& msg = req.replicate_state();
  LOG(VERBOSE) << "HandleReplicateStateRequest";
  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
  if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
    return SendE2EError(ctx, target, req.request_id(), COUNTED_ERROR(Replicate_RaftState));
  }
  if (msg.group_id() != raft_.loaded.raft->group_id()) {
    return SendE2EError(ctx, target, req.request_id(), COUNTED_ERROR(Replicate_GroupMismatch));
  }
  // push_state will live for the duration of this replication.
  auto push_state = std::make_shared(
      raft_.loaded.raft->log().oldest_stored_idx(), target, req);

  // `target` has requested replication from us, so now we need to ship data
  // down to it.  We do so by sending some number of E2E transactions to `target`,
  // each containing a subset of the data.  Each call to SendNextReplicationState
  // will send one such E2E transaction, wait for it to complete, then send
  // another.  So, by starting multiple here, we allow ourselves to send many
  // at once over the network without waiting for a response from `target`.
  // The multiple requests use (a shared pointer to) a single push_state to
  // coordinate which data has been sent already, which should be sent in the
  // next call to SendNextReplicationState (either here or on a callback to a
  // previous one), coordinating when we're done, and remembering which transaction
  // to complete when we are.
  auto pipeline = enclave_config(ctx)->raft().replication_pipeline();
  for (uint32_t i = 0; i < pipeline && !push_state->finished_sending; i++) {
    SendNextReplicationState(ctx, push_state);
  }
  return error::OK;
}

// Sends the next chunk of replication data described by `push_state` to its
// target as one e2e transaction: log entries starting at
// push_state->logs_from_idx_inclusive, up to roughly replication_chunk_bytes.
// Must not be called once push_state->finished_sending is set (CHECKed).
void Core::SendNextReplicationState(context::Context* ctx, std::shared_ptr push_state) {
  MEASURE_CPU(ctx, cpu_core_repl_send);
  CHECK(!push_state->finished_sending);
  auto push = ctx->Protobuf();
  auto out = push->mutable_replicate_state_push();
  out->set_replication_id(push_state->replication_id);
  out->set_replication_sequence(push_state->replication_sequence++);
  out->set_first_log_idx(push_state->logs_from_idx_inclusive);
  size_t size = 0;
  bool at_commit_idx = false;
  auto replication_chunk_bytes = enclave_config(ctx)->raft().replication_chunk_bytes();
  // Add log entries until the chunk is full or we pass the last entry that has
  // been applied to the database.
  for (auto iter = raft_.loaded.raft->log().At(push_state->logs_from_idx_inclusive); ; iter.Next()) {
    if (!iter.Valid() || iter.Index() > raft_.loaded.db_last_applied_log) {
      LOG(VERBOSE) << "surpassed commit idx " << raft_.loaded.db_last_applied_log;
+ at_commit_idx = true; + break; + } + *out->add_entries() = *iter.Entry(); + size += iter.SerializedSize(); + if (size >= replication_chunk_bytes) { break; } + } + + // our db rows represent the state at `raft_.loaded.db_commit`, so we can + // only send them if after this message the requester will be at `raft_.loaded.db_commit` + if (at_commit_idx) { + size_t rows_to_send = + (replication_chunk_bytes - out->ByteSizeLong()) + / db_protocol_->MaxRowSerializedSize(); + if (rows_to_send) { // if we've got space + auto [row_id, err] = raft_.loaded.db->RowsAsProtos(ctx, push_state->db_from_key_exclusive, rows_to_send, out->mutable_rows()); + if (err != error::OK) { + LOG(WARNING) << "Error getting rows as protos: " << err; + if (!push_state->sent_response.exchange(true)) { + SendE2EError(ctx, push_state->target, push_state->tx, err); + } + return; + } + push_state->db_from_key_exclusive = row_id; + if ((size_t) out->rows_size() < rows_to_send) { + out->set_db_to_end(true); + push_state->finished_sending = true; + LOG(INFO) << "Final data being sent"; + } + } + } + *out->mutable_committed_membership() = raft_.loaded.raft->committed_membership().AsProto(); + IDLOG(INFO) << "Replication: sending " << out->entries_size() << " entries (from " + << push_state->logs_from_idx_inclusive << ") and " << out->rows_size() << " rows to " + << push_state->target; + + // Update push state based on our output. 
+  push_state->logs_from_idx_inclusive += out->entries_size();
+  bool last_sent_transaction = push_state->finished_sending;
+  SendE2ETransaction(ctx, push_state->target, *push, true,
+      [this, push_state, last_sent_transaction](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp) {
+        if (push_state->sent_response.load()) {
+          return;
+        } else if (err != error::OK && !push_state->sent_response.exchange(true)) {
+          SendE2EError(ctx, push_state->target, push_state->tx, err);
+          return;
+        }
+        ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
+        // `last_sent_transaction` will be set if this is the last transaction we send.
+        // `push_state->finished_sending` will be set if we've sent that transaction,
+        // whether this is it or not.
+        if (last_sent_transaction && !push_state->sent_response.exchange(true)) {
+          LOG(INFO) << "All replication state pushes complete, returning success for replication";
+          SendE2EError(ctx, push_state->target, push_state->tx, error::OK);
+        } else if (!push_state->finished_sending) {
+          SendNextReplicationState(ctx, push_state);
+        }
+      });
+}
+
+// Handle one `ReplicateStatePush` chunk (log entries plus database rows) while
+// this replica is in RAFTSTATE_LOADING.
+//
+// The push is rejected with a counted error if we are not loading, if its
+// first log index does not continue our log, if it carries no committed
+// membership, or if its replication id / sequence number do not match what we
+// expect.  Otherwise the entries are appended to the loading log (and applied
+// to the loading database where required), and the rows are loaded.
+//
+// Returns error::OK on success, or the first error encountered.
+error::Error Core::HandleReplicateStatePush(context::Context* ctx, const e2e::ReplicateStatePush& repl) {
+  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
+  MEASURE_CPU(ctx, cpu_core_repl_recv);
+  if (raft_.state != svr2::RAFTSTATE_LOADING) {
+    LOG(ERROR) << "Running RequestRaftReplication callback while not loading";
+    return COUNTED_ERROR(Core_RaftState);
+  }
+  if (raft_.loading.log->next_idx() > 1 && repl.first_log_idx() != raft_.loading.log->next_idx()) {
+    LOG(ERROR) << "log index mismatch: log.next=" << raft_.loading.log->next_idx()
+               << " repl.first=" << repl.first_log_idx();
+    return COUNTED_ERROR(Replicate_LogIndexMismatch);
+  } else if (!repl.has_committed_membership()) {
+    return COUNTED_ERROR(Replicate_MissingCommittedMembership);
+  } else if (repl.replication_id() != raft_.loading.replication_id) {
+    return COUNTED_ERROR(Replicate_ReplicationID);
+  } else if (repl.replication_sequence() != raft_.loading.replication_sequence++) {
+    // Note: the expected sequence number is advanced whenever this comparison
+    // is evaluated, whether or not it matches.
+    return COUNTED_ERROR(Replicate_ReplicationSequence);
+  }
+  error::Error err;
+  std::tie(raft_.loading.mem, err) = raft::Membership::FromProto(repl.committed_membership());
+  RETURN_IF_ERROR(err);
+  LOG(INFO) << "received " << repl.entries_size() << " logs starting at " << repl.first_log_idx()
+            << " and " << repl.rows_size() << " database rows (have " << raft_.loading.db->row_count() << " rows)";
+
+  raft::Log* log = raft_.loading.log.get();
+  // We could be receiving the first set of entries from a replica's truncated set of
+  // logs.  In that case, if we were to append the first entry as log index 1, we'd have
+  // a mismatch between our log index and theirs.  So, when our log is empty, use their
+  // first log index to set what our next index will be.
+  if (log->empty()) {
+    log->SetNextIdx(repl.first_log_idx());
+  }
+  // The `ReplicateStatePush` we are processing contains log entries that have been
+  // committed by the sender and db rows that reflect the state up to the last log sent.
+  // This leaves us with three possible scenarios for each log entry we receive in this push:
+  //
+  // 1. The log entry affects a row that is also sent in this push. In this case the sender
+  //    has already applied this log entry and we MUST NOT apply it again.
+  // 2. The log entry affects a row outside the range of rows that has been sent. In this case
+  //    the sender will send that row with this log applied in a later message. We MUST NOT
+  //    apply this log.
+  // 3. The log entry affects a row in the range that had been sent before this request (less
+  //    than or equal to the current max key of the loading database). The sender will not
+  //    send this row again and we MUST apply the log.
+  //
+  // At this point, before we add the new rows to the loading database, if a log entry has
+  // a backup_id greater than the current max key of the loading database then we are
+  // in situation (1) or (2) and MUST NOT apply the log. Otherwise we are in situation (3) and
+  // MUST apply the log.
+  //
+  // `MaybeApplyLogToReplicatingDatabase` will apply logs according to this rule. Once
+  // these logs are selectively applied we can add the rows to the loading database.
+  for (int i = 0; i < repl.entries_size(); i++) {
+    const auto& entry = repl.entries(i);
+    // All of our logs are committed logs, so we allow truncation up to the point where
+    // we only have our most recent entry in the log.
+    RETURN_IF_ERROR(log->Append(entry, log->last_idx()));
+    RETURN_IF_ERROR(MaybeApplyLogToReplicatingDatabase(ctx, entry));
+  }
+  LOG(VERBOSE) << "Now have logs in [" << log->oldest_stored_idx() << ", " << log->last_idx() << "]";
+  if (repl.rows_size()) {
+    // Load the pushed rows into the loading database.  A load failure is
+    // propagated as-is rather than being ignored or misreported as an
+    // ordering problem.
+    auto [last, load_err] = raft_.loading.db->LoadRowsFromProtos(ctx, repl.rows());
+    RETURN_IF_ERROR(load_err);
+    // Ensure that rows are provided in order: the key reported for this chunk
+    // must be strictly greater than the largest key loaded by earlier chunks.
+    if (last <= raft_.loading.lexigraphically_largest_row_loaded_into_db) {
+      return COUNTED_ERROR(Core_ReplicationOutOfOrder);
+    }
+    raft_.loading.lexigraphically_largest_row_loaded_into_db = last;
+  }
+  return error::OK;
+}
+
+// Apply log entries to the loading database if they are in the database's currently loaded range.
+error::Error Core::MaybeApplyLogToReplicatingDatabase(context::Context* ctx, const raft::LogEntry& entry) {
+  if (raft_.state != svr2::RAFTSTATE_LOADING ||
+      raft_.loading.db.get() == nullptr) {
+    return COUNTED_ERROR(Core_RaftState);
+  } else if (raft_.loading.lexigraphically_largest_row_loaded_into_db.empty() || entry.data().size() == 0) {
+    // We don't want to apply this log to the database, since either we have no rows in the database or this is not a client log.
+    return error::OK;
+  }
+  auto clog = db_protocol_->LogPB(ctx);
+  if (!clog->ParseFromString(entry.data())) {
+    return COUNTED_ERROR(Core_ReplicatedLogSerialization);
+  }
+  // The log's key is beyond the largest row loaded so far, so skip it.
+  if (raft_.loading.lexigraphically_largest_row_loaded_into_db < db_protocol_->LogKey(*clog)) {
+    return error::OK;
+  }
+  RETURN_IF_ERROR(db_protocol_->ValidateClientLog(*clog));
+  raft_.loading.db->Run(ctx, *clog);
+  return error::OK;
+}
+
+// Handle a peer's request to join our Raft group.  Fails if we are not part of
+// a group or if `from` is already listed in the current membership; otherwise
+// proposes a membership with `from` appended and, on success, copies the
+// resulting log location into `resp`.
+error::Error Core::HandleRequestRaftMembership(context::Context* ctx, const peerid::PeerID& from, e2e::TransactionResponse* resp) {
+  IDLOG(VERBOSE) << "HandleRequestRaftMembership " << from;
+  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
+  if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
+    return COUNTED_ERROR(Core_RaftState);
+  }
+  std::string peer_string = from.AsString();
+  raft::ReplicaGroup g = raft_.loaded.raft->membership().AsProto();
+  for (int i = 0; i < g.replicas_size(); i++) {
+    if (g.replicas(i).peer_id() == peer_string) {
+      return COUNTED_ERROR(Core_DuplicateMembershipPeer);
+    }
+  }
+  g.add_replicas()->set_peer_id(peer_string);
+  auto [loc, err] = raft_.loaded.raft->ReplicaGroupChange(ctx, g);
+  if (err == error::OK) {
+    RaftStep(ctx);
+    resp->mutable_raft_membership_response()->MergeFrom(loc);
+  }
+  return err;
+}
+
+// Handle a peer's request to become a voting member.  `from` must already be a
+// member and must not already be voting.  On success, copies the log location
+// of the proposed membership change into `resp`.
+error::Error Core::HandleRequestRaftVoting(context::Context* ctx, const peerid::PeerID& from, e2e::TransactionResponse* resp) {
+  IDLOG(VERBOSE) << "HandleRequestRaftVoting " << from;
+  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
+  if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
+    return COUNTED_ERROR(Core_RaftState);
+  }
+  if (raft_.loaded.raft->membership().all_replicas().count(from) != 1) {
+    return COUNTED_ERROR(Core_VotingRequestedForNonMember);
+  } else if (raft_.loaded.raft->membership().voting_replicas().count(from) != 0) {
+    return COUNTED_ERROR(Core_VotingRequestedForVotingMember);
+  }
+
+  // This does not respect the max_voting attribute of the RaftConfig.  That's
+  // fine, though, because the leader will enforce that before accepting this
+  // change.
+  raft::ReplicaGroup g = raft_.loaded.raft->membership().AsProto();
+  // Serialize the peer ID once instead of once per loop iteration.
+  const std::string peer_string = from.AsString();
+  for (int i = 0; i < g.replicas_size(); i++) {
+    if (g.replicas(i).peer_id() == peer_string) {
+      g.mutable_replicas(i)->set_voting(true);
+      break;
+    }
+  }
+  auto [loc, err] = raft_.loaded.raft->ReplicaGroupChange(ctx, g);
+  if (err == error::OK) {
+    RaftStep(ctx);
+    resp->mutable_raft_voting_response()->MergeFrom(loc);
+  }
+  return err;
+}
+
+// Handle a peer's request to write `data` into the Raft log.  Requires that we
+// are part of a group with at least `min_voting_replicas` voting replicas.  On
+// success, copies the log location of the new entry into `resp`.
+error::Error Core::HandleRaftWrite(context::Context* ctx, const std::string& data, e2e::TransactionResponse* resp) {
+  LOG(VERBOSE) << "HandleRaftWrite";
+  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
+  if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
+    return COUNTED_ERROR(Core_RaftState);
+  }
+  if (raft_.loaded.raft->membership().voting_replicas().size() < raft_.loaded.group_config.min_voting_replicas()) {
+    return COUNTED_ERROR(Core_NotEnoughVotingReplicas);
+  }
+  auto [loc, err] = raft_.loaded.raft->ClientRequest(ctx, data);
+  if (err == error::OK) {
+    RaftStep(ctx);
+    resp->mutable_raft_write()->MergeFrom(loc);
+  }
+  return err;
+}
+
+// Record the timestamp reported by peer `from`, then re-evaluate group time.
+void Core::HandleNewTimestamp(context::Context* ctx, const peerid::PeerID& from, uint64_t unix_secs) {
+  clock_.SetRemoteTime(ctx, from, unix_secs);
+  MaybeUpdateGroupTime(ctx);
+}
+
+// Handle a peer's request to be removed from the Raft group.  Every outcome is
+// reported to the peer through SendE2EError for transaction `tx` (immediately
+// on failure, or from the log-transaction callback otherwise), so this function
+// itself always returns error::OK.
+error::Error Core::HandlePeerRequestedRaftRemoval(context::Context* ctx, const peerid::PeerID& from, internal::TransactionID tx) {
+  IDLOG(VERBOSE) << "HandlePeerRequestedRaftRemoval " << from;
+  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
+  if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
+    SendE2EError(ctx, from, tx, COUNTED_ERROR(Core_RaftState));
+    return error::OK;
+  }
+  std::string peer_string = from.AsString();
+  raft::ReplicaGroup g = raft_.loaded.raft->membership().AsProto();
+  // Build the next membership as a copy of the current one minus `from`.
+  raft::ReplicaGroup next = g;
+  next.clear_replicas();
+  bool found_peer = false;
+  for (int i = 0; i < g.replicas_size(); i++) {
+    if (g.replicas(i).peer_id() == peer_string) {
+      found_peer = true;
+    } else {
+      *next.add_replicas() = g.replicas(i);
+    }
+  }
+  if (!found_peer) {
+    SendE2EError(ctx, from, tx, COUNTED_ERROR(Core_RemoveNonexistentMember));
+    return error::OK;
+  }
+  auto [loc, err] = raft_.loaded.raft->ReplicaGroupChange(ctx, next);
+  if (err != error::OK) {
+    SendE2EError(ctx, from, tx, err);
+    return error::OK;
+  }
+  peerid::PeerID from_copy = from;
+  AddLogTransaction(ctx, loc, [this, f = std::move(from_copy), tx](
+      context::Context* ctx,
+      error::Error err,
+      const raft::LogEntry* entry,
+      const db::DB::Response* resp) {
+    SendE2EError(ctx, f, tx, err);
+  });
+  RaftStep(ctx);
+  return error::OK;
+}
+
+// Send any pending Raft messages and process any newly committed logs.  Only
+// valid in the two "loaded" Raft states.
+void Core::RaftStep(context::Context* ctx) {
+  CHECK(raft_.state == svr2::RAFTSTATE_LOADED_PART_OF_GROUP
+        || raft_.state == svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP);
+  RaftSendMessages(ctx);
+  RaftHandleCommittedLogs(ctx);
+}
+
+// Deliver every message Raft has queued.  A message with an explicit
+// destination goes to that peer only; otherwise it goes to all Raft peers.
+// Send failures are logged and otherwise ignored.
+void Core::RaftSendMessages(context::Context* ctx) {
+  // Send out any messages that Raft has for us.
+  std::vector messages = raft_.loaded.raft->SendableMessages();
+  for (size_t i = 0; i < messages.size(); i++) {
+    std::set send_to;
+    if (messages[i].to().has_value()) {
+      send_to.insert(*messages[i].to());
+    } else {
+      send_to = raft_.loaded.raft->peers();
+    }
+    const raft::RaftMessage& raft_msg = messages[i].message();
+    for (const auto& peer : send_to) {
+      auto e2e_msg = ctx->Protobuf();
+      e2e_msg->mutable_raft_message()->MergeFrom(raft_msg);
+      error::Error peer_err = peer_manager_->SendToPeer(ctx, peer, *e2e_msg);
+      if (peer_err != error::OK) {
+        // If we've failed here, our peer is probably in a DISCONNECTED state.
+        // This will be handled eventually by having the peers reset themselves,
+        // at which point we'll get a new `connected` e2e message, which will
+        // call Raft's ResetPeer() and restart sends of messages to this peer.
+        LOG(WARNING) << "failed to generate peer raft message to " << peer << ": " << peer_err;
+        continue;
+      }
+    }
+  }
+}
+
+// Register `cb` to be invoked when the log at `loc.idx()` is handled, recording
+// the term and hash chain expected at that index.
+void Core::AddLogTransaction(context::Context* ctx, const raft::LogLocation& loc, LogTransactionCallback cb) {
+  ACQUIRE_LOCK(outstanding_log_transactions_mu_, ctx, lock_core_log_txns);
+  LogTransaction log_tx = {
+    .term = loc.term(),
+    // `cb` is a by-value parameter that is not used again; move it instead of copying.
+    .cb = std::move(cb),
+    .expected_hash_chain = loc.hash_chain(),
+  };
+  outstanding_log_transactions_.emplace(loc.idx(), std::move(log_tx));
+}
+
+// Build the log-transaction callback used for client requests.  When invoked,
+// the callback either replies to host transaction `tx` with an error and
+// removes the client, or encrypts the database response for the client and
+// sends it back as an existing-client reply.
+Core::LogTransactionCallback Core::ClientLogTransaction(context::Context* ctx, client::ClientID client_id, internal::TransactionID tx) {
+  // Record information about this ClientLog message so we can respond to the client later.
+  return [this, client_id, tx](
+      context::Context* ctx,
+      error::Error err,
+      const raft::LogEntry* entry,
+      const db::DB::Response* response) {
+    if (err == error::Core_LogTransactionCancelled) {
+      COUNTER(core, client_transaction_cancelled)->Increment();
+      LOG(VERBOSE) << "- client " << client_id << " - cancelled";
+      ReplyWithError(ctx, tx, COUNTED_ERROR(Client_TransactionCancelled));
+      client_manager_->RemoveClient(ctx, client_id);
+    } else if (err != error::OK) {
+      COUNTER(core, client_transaction_error)->Increment();
+      LOG(VERBOSE) << "- client " << client_id << " - error";
+      ReplyWithError(ctx, tx, err);
+      client_manager_->RemoveClient(ctx, client_id);
+    } else if (response == nullptr) {
+      COUNTER(core, client_transaction_invalid)->Increment();
+      LOG(VERBOSE) << "- client " << client_id << " - invalid";
+      ReplyWithError(ctx, tx, COUNTED_ERROR(Client_TransactionInvalid));
+      client_manager_->RemoveClient(ctx, client_id);
+    } else if (
+        client::Client* client = client_manager_->GetClient(ctx, client_id);
+        client == nullptr) {
+      COUNTER(core, client_transaction_dne)->Increment();
+      LOG(VERBOSE) << "- client " << client_id << " - does_not_exist";
+      ReplyWithError(ctx, tx, COUNTED_ERROR(Client_AlreadyClosed));
+      client_manager_->RemoveClient(ctx, client_id);
+    } else if (
+        auto [ciphertext, encrypt_err] = client->EncryptResponse(ctx, *response);
+        encrypt_err != error::OK) {
+      COUNTER(core, client_transaction_encrypterr)->Increment();
+      LOG(VERBOSE) << "- client " << client_id << " - encrypt_fail:" << encrypt_err;
+      ReplyWithError(ctx, tx, encrypt_err);
+      client_manager_->RemoveClient(ctx, client_id);
+    } else {
+      COUNTER(core, client_transaction_success)->Increment();
+      LOG(VERBOSE) << "- client " << client_id << " - success";
+      auto enclave_msg = ctx->Protobuf();
+      auto resp = enclave_msg->mutable_h2e_response();
+      resp->set_request_id(tx);
+      auto existing_client = resp->mutable_existing_client_reply();
+      *existing_client->mutable_data() = std::move(ciphertext);
+      sender::Send(*enclave_msg);
+    }
+  };
+}
+
+// Submit `data` to the Raft log and arrange for `cb` to run once the resulting
+// log index is handled.  If we are the leader the entry is added locally;
+// otherwise it is forwarded to the known leader as an E2E transaction.
+// Returns an error if we are not part of a group, if the leader lacks enough
+// voting replicas, if the local append fails, or if no leader is known.
+// Failures of the forwarded transaction are reported through `cb` instead.
+error::Error Core::RaftWriteLogTransaction(context::Context* ctx, const std::string& data, Core::LogTransactionCallback cb) {
+  ACQUIRE_LOCK(raft_.mu, ctx, lock_core_raft);
+  if (raft_.state != svr2::RAFTSTATE_LOADED_PART_OF_GROUP) {
+    return COUNTED_ERROR(Core_RaftState);
+  }
+  if (raft_.loaded.raft->is_leader()) {
+    if (raft_.loaded.raft->membership().voting_replicas().size() < raft_.loaded.group_config.min_voting_replicas()) {
+      return COUNTED_ERROR(Core_NotEnoughVotingReplicas);
+    }
+    // Add the ClientLog message to the Raft log
+    auto [loc, raft_err] = raft_.loaded.raft->ClientRequest(ctx, data);
+    if (raft_err != error::OK) {
+      return raft_err;
+    }
+    AddLogTransaction(ctx, loc, cb);
+    RaftStep(ctx);
+  } else if (raft_.loaded.raft->leader().has_value()) {
+    // Forward this ClientLog to the leader to be added to the log
+    auto txn = ctx->Protobuf();
+    txn->set_raft_write(data);
+    SendE2ETransaction(ctx, *raft_.loaded.raft->leader(), *txn, true,
+        [this, cb](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp) {
+          // A transport-level success may still carry an error status from the leader.
+          if (err == error::OK && resp->inner_case() == e2e::TransactionResponse::kStatus) {
+            err = resp->status();
+          }
+          if (err != error::OK) {
+            cb(ctx, err, nullptr, nullptr);
+            return;
+          }
+          // Record information about this ClientLog message so we can respond to the client later.
+          // This replica is responsible for responding to the client (and is the only replica with
+          // the Noise state that is needed to do that).
+          if (resp->inner_case() != e2e::TransactionResponse::kRaftWrite) {
+            cb(ctx, COUNTED_ERROR(Core_IncorrectE2EResponseType), nullptr, nullptr);
+            return;
+          }
+          AddLogTransaction(ctx, resp->raft_write(), cb);
+        });
+  } else {
+    return COUNTED_ERROR(Core_LeaderUnknown);
+  }
+  return error::OK;
+}
+
+// Send `unix_secs` to peer `to` as a new-timestamp E2E transaction.  Failures
+// are logged and otherwise ignored.
+void Core::SendTimestamp(context::Context* ctx, peerid::PeerID to, uint64_t unix_secs) {
+  auto req = ctx->Protobuf();
+  req->set_new_timestamp_unix_secs(unix_secs);
+  SendE2ETransaction(
+      ctx, to, *req, true,
+      [unix_secs, to](context::Context* ctx, error::Error err, const e2e::TransactionResponse* resp) {
+        // Ignore, but log error.
+        if (err != error::OK) {
+          LOG(INFO) << "Failed to send timestamp (" << unix_secs << ") to " << to << ": " << err;
+        }
+      });
+}
+
+// Send our local time to every connected peer, then schedule this function to
+// run again after `send_timestamp_ticks` ticks (at least 1).
+void Core::SendTimestampToAll(context::Context* ctx) {
+  auto peers = peer_manager_->ConnectedPeers(ctx);
+  for (auto peer : peers) {
+    SendTimestamp(ctx, std::move(peer), clock_.GetLocalTime());
+  }
+  util::Ticks next = std::max(1U, enclave_config(ctx)->send_timestamp_ticks());
+  timeout_.SetTimeout(ctx, next, [this](context::Context* ctx) {
+    SendTimestampToAll(ctx);
+  });
+}
+
+// Send a transaction response for request `id` carrying status `err` to peer
+// `from`.  Despite the name, this is also used to report error::OK.  Returns
+// the result of sending to the peer.
+error::Error Core::SendE2EError(context::Context* ctx, const peerid::PeerID& from, internal::TransactionID id, error::Error err) {
+  auto e2e = ctx->Protobuf();
+  auto out = e2e->mutable_transaction_response();
+  out->set_request_id(id);
+  out->set_status(err);
+  if (out->status() != error::OK) {
+    IDLOG(VERBOSE) << "request " << id << " from " << from << " error: " << err;
+  }
+  return peer_manager_->SendToPeer(ctx, from, *e2e);
+}
+
+// Drain Raft's committed logs in order.  For each one: record it as the last
+// applied index, process any membership change, apply it to the database, and
+// run the log transactions waiting on that index.
+void Core::RaftHandleCommittedLogs(context::Context* ctx) {
+  // See if Raft has any committed logs for us.
+  MEASURE_CPU(ctx, cpu_core_committed_logs);
+  while (true) {
+    auto [idx, entry] = raft_.loaded.raft->TakeCommittedLog();
+    if (idx == 0) {
+      // There's no additional logs, we're done!
+      return;
+    }
+    raft_.loaded.db_last_applied_log = idx;
+    LOG(VERBOSE) << "at db_last_applied_log " << idx;
+    GAUGE(core, last_index_applied_to_db)->Set(idx);
+    if (entry.has_membership_change()) {
+      HandleRaftMembershipChange(ctx, idx, entry.term(), entry.membership_change());
+    }
+    db::DB::Response* response = RaftApplyLogToDatabase(ctx, idx, entry);
+    // Unless this log contained a valid client transaction,
+    // [response] will be null at this point.
+    HandleLogTransactionsForRaftLog(ctx, idx, entry, response);
+    COUNTER(core, raft_log_applied)->Increment();
+  }
+}
+
+// React to a committed membership change.  If we were part of the group and the
+// new membership no longer contains us, switch to REQUESTING_MEMBERSHIP and
+// fail every outstanding log transaction with Core_RemovedFromRaft.  If we were
+// requesting membership and the new membership contains us, switch to
+// PART_OF_GROUP.  Any other Raft state is a fatal error.
+void Core::HandleRaftMembershipChange(
+    context::Context* ctx,
+    raft::LogIdx idx,
+    raft::TermId term,
+    const raft::ReplicaGroup& membership_change) {
+  switch (raft_.state) {
+    case svr2::RAFTSTATE_LOADED_PART_OF_GROUP: {
+      if (!ContainsMe(raft_.loaded.raft->me(), membership_change)) {
+        LOG(WARNING) << "I've been removed from Raft at index " << idx;
+        raft_.state = svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP;
+        ACQUIRE_LOCK(outstanding_log_transactions_mu_, ctx, lock_core_log_txns);
+        // Each transaction is erased as the iterator advances past it.
+        for (auto iter = outstanding_log_transactions_.begin();
+             iter != outstanding_log_transactions_.end();
+             iter = outstanding_log_transactions_.erase(iter)) {
+          const auto& log_tx = iter->second;
+          log_tx.cb(ctx, COUNTED_ERROR(Core_RemovedFromRaft), nullptr, nullptr);
+        }
+      }
+    } break;
+    case svr2::RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP: {
+      if (ContainsMe(raft_.loaded.raft->me(), membership_change)) {
+        LOG(INFO) << "I've been added to Raft at index " << idx;
+        raft_.state = svr2::RAFTSTATE_LOADED_PART_OF_GROUP;
+      }
+    } break;
+    default:
+      CHECK(nullptr == "in HandleRaftMembershipChange but not part of group or requesting membership");
+      break;
+  }
+}
+
+// Apply a committed log entry to the loaded database.  Returns the database's
+// response, or nullptr when the entry carries no data, fails to parse, or
+// fails validation.
+db::DB::Response* Core::RaftApplyLogToDatabase(
+    context::Context* ctx,
+    raft::LogIdx idx,
+    const raft::LogEntry& committed_entry) {
+  if (committed_entry.data().size() == 0) {
+    // This is an internal-to-Raft log, we don't need to care.
+    // These are generated on leader election, and will eventually
+    // be used for membership changes as well.
+    return nullptr;
+  }
+  auto client_log = db_protocol_->LogPB(ctx);
+  if (!client_log->ParseFromString(committed_entry.data())) {
+    LOG(ERROR) << "raft log message does not parse: " << idx;
+    return nullptr;
+  }
+  error::Error validate_err = db_protocol_->ValidateClientLog(*client_log);
+  if (validate_err != error::OK) {
+    LOG(ERROR) << "raft log message invalid: " << idx << " - " << validate_err;
+    return nullptr;
+  }
+  return raft_.loaded.db->Run(ctx, *client_log);
+}
+
+// Run and remove every outstanding log transaction registered for `idx`.  A
+// transaction is cancelled if its term differs from the entry's term, fails if
+// it has a non-empty expected hash chain that does not match the entry's, and
+// otherwise succeeds with `entry` and `response` (which may be null).
+void Core::HandleLogTransactionsForRaftLog(context::Context* ctx, raft::LogIdx idx, const raft::LogEntry& entry, const db::DB::Response* response) {
+  // See if this is a log we should handle.
+  const char* type =
+      entry.data().size() == 0
+      ? "raft_internal"
+      : response != nullptr
+      ? "valid_client"
+      : "invalid";
+  LOG(VERBOSE) << "raft log " << idx << " at term " << entry.term() << " - " << type;
+  ACQUIRE_LOCK(outstanding_log_transactions_mu_, ctx, lock_core_log_txns);
+  auto [iter, upper] = outstanding_log_transactions_.equal_range(idx);
+  for (; iter != upper; iter = outstanding_log_transactions_.erase(iter)) {
+    const LogTransaction& log_tx = iter->second;
+    if (log_tx.term != entry.term()) {
+      COUNTER(core, log_transactions_cancelled)->Increment();
+      log_tx.cb(ctx, COUNTED_ERROR(Core_LogTransactionCancelled), nullptr, nullptr);
+    } else if (log_tx.expected_hash_chain.size() > 0  // ignore hash chain if length is 0
+        && !util::ConstantTimeEquals(log_tx.expected_hash_chain, entry.hash_chain())) {
+      log_tx.cb(ctx, COUNTED_ERROR(Core_InvalidLogTransactionHashChain), nullptr, nullptr);
+    } else {
+      COUNTER(core, log_transactions_success)->Increment();
+      log_tx.cb(ctx, error::OK, &entry, response);
+    }
+  }
+}
+
+// Send `req` to peer `to` as a new E2E transaction and register `callback` in
+// outstanding_e2e_transactions_ under the newly assigned transaction ID.
+//
+// If the send fails, `callback` is invoked inline with the error (after
+// releasing e2e_txn_mu_) and nothing is registered.  If `with_timeout` is
+// true, a timeout is scheduled that removes the transaction and invokes its
+// callback with Core_E2ETransactionTimeout if it is still outstanding.
+void Core::SendE2ETransaction(
+    context::Context* ctx,
+    const peerid::PeerID& to,
+    const e2e::TransactionRequest& req,
+    bool with_timeout,
+    E2ECallback callback) {
+  ACQUIRE_NAMED_LOCK(lock, e2e_txn_mu_, ctx, lock_core_e2e_txns);
+  internal::TransactionID tx = ++e2e_txn_id_;
+  auto e2e = ctx->Protobuf();
+  e2e->mutable_transaction_request()->MergeFrom(req);
+  e2e->mutable_transaction_request()->set_request_id(tx);
+  error::Error err = peer_manager_->SendToPeer(ctx, to, *e2e);
+  if (err != error::OK) {
+    IDLOG(VERBOSE) << "failed to start transaction " << tx << " to " << to << ": " << err;
+    lock.unlock();
+    // This is a problematic codepath right now, as we call the callback inline.
+    // Sometimes, the callback has to acquire a lock that's already acquired
+    // by SendE2ETransaction's caller.  The optimal approach would be to defer
+    // this callback to some time when the caller has returned.
+    callback(ctx, err, nullptr);
+    return;
+  }
+  LOG(VERBOSE) << "successfully started transaction " << tx << " to " << to;
+  timeout::Cancel tc;
+  if (with_timeout) {
+    tc = timeout_.SetTimeout(ctx, enclave_config(ctx)->e2e_txn_timeout_ticks(),
+        [this, tx, to](context::Context* ctx) {
+          ACQUIRE_NAMED_LOCK(lock, e2e_txn_mu_, ctx, lock_core_e2e_txns);
+          auto f = outstanding_e2e_transactions_.find(tx);
+          if (f == outstanding_e2e_transactions_.end()) return;
+          LOG(INFO) << "e2e transaction " << tx << " to " << to << " timed out";
+          E2ECallback cb = std::move(f->second.callback);
+          outstanding_e2e_transactions_.erase(f);
+          // Release the lock before running the callback.
+          lock.unlock();
+          cb(ctx, error::Core_E2ETransactionTimeout, nullptr);
+        });
+  }
+  outstanding_e2e_transactions_[tx] = {
+    // `callback` is a by-value parameter that is not used again; move it instead of copying.
+    .callback = std::move(callback),
+    .timeout_cancel = tc,
+  };
+}
+
+}  // namespace svr2::core
+
diff --git a/enclave/core/core.h b/enclave/core/core.h
new file mode 100644
index 0000000..74238f0
--- /dev/null
+++ b/enclave/core/core.h
@@ -0,0 +1,311 @@
+// Copyright 2023 Signal Messenger, LLC
+// SPDX-License-Identifier: AGPL-3.0-only
+
+#ifndef __SVR2_CORE_CORE_H__
+#define __SVR2_CORE_CORE_H__
+
+#include
+#include
+#include
+#include "proto/enclaveconfig.pb.h"
+#include "proto/error.pb.h"
+#include "proto/msgs.pb.h"
+#include "util/macros.h" +#include "peerid/peerid.h" +#include "peers/peers.h" +#include "context/context.h" +#include "raft/log.h" +#include "raft/raft.h" +#include "client/client.h" +#include "sip/hasher.h" +#include "db/db.h" +#include "core/internal.h" +#include "util/macros.h" +#include "util/ticks.h" +#include "timeout/timeout.h" +#include "groupclock/groupclock.h" + +namespace svr2::core { + +// Core is the core singleton of a running enclave. Each running enclave +// should have exactly one of these, created on initialization. +class Core { + public: + DELETE_COPY_AND_ASSIGN(Core); + + // Receive a message from the host. + error::Error Receive(context::Context* ctx, const UntrustedMessage& msg); + + // Peer ID for this core. + const peerid::PeerID& ID() const { return peer_manager_->ID(); } + + // Create a core from a given config. + static std::pair, error::Error> Create( + context::Context* ctx, + const enclaveconfig::InitConfig& config); + +#ifdef IS_TEST + bool serving() const { + util::unique_lock lock(raft_.mu); + return raft_.state == svr2::RAFTSTATE_LOADED_PART_OF_GROUP; + } + bool leader() const { + util::unique_lock lock(raft_.mu); + return raft_.state == svr2::RAFTSTATE_LOADED_PART_OF_GROUP && raft_.loaded.raft->is_leader(); + } + bool voting() const { + util::unique_lock lock(raft_.mu); + return raft_.state == svr2::RAFTSTATE_LOADED_PART_OF_GROUP && raft_.loaded.raft->voting(); + } + size_t num_voting() const { + util::unique_lock lock(raft_.mu); + return raft_.loaded.raft->membership().voting_replicas().size(); + } + size_t num_members() const { + util::unique_lock lock(raft_.mu); + return raft_.loaded.raft->membership().all_replicas().size(); + } + std::set all_replicas() const { + util::unique_lock lock(raft_.mu); + return raft_.loaded.raft->membership().all_replicas(); + } +#endif + + private: + struct ReplicationPushState { + ReplicationPushState(raft::LogIdx idx, const peerid::PeerID& to, const e2e::TransactionRequest& req) + : 
logs_from_idx_inclusive(idx), + db_from_key_exclusive(""), + finished_sending(false), + target(to), + tx(req.request_id()), + replication_id(req.replicate_state().replication_id()), + replication_sequence(0), + sent_response(false) {} + + raft::LogIdx logs_from_idx_inclusive; // GUARDED_BY(raft_.mu) + std::string db_from_key_exclusive; // GUARDED_BY(raft_.mu) + bool finished_sending; // GUARDED_BY(raft_.mu) + const peerid::PeerID target; + const internal::TransactionID tx; + const uint64_t replication_id; + uint64_t replication_sequence; // GUARDED_BY(raft_mu) + std::atomic sent_response; + }; + + Core(const enclaveconfig::RaftGroupConfig& group_config); + // Init this core object. This function should be + // called exactly once for each Core object, and sould be the first function + // called subsequent to construction. + error::Error Init( + context::Context* ctx, + const enclaveconfig::EnclaveConfig& config, + util::UnixSecs initial_timestamp_unix_secs); + + //// Top-level callers, called by Receive(), and their subfunctions. + + // Handle a request from the host + error::Error HandleHostToEnclave(context::Context* ctx, const HostToEnclaveRequest& msg); + // Handle a request for a new client + void HandleNewClient(context::Context* ctx, const NewClientRequest& msg, internal::TransactionID tx); + // Handle a message being passed through the host to an existing client + error::Error HandleExistingClient(context::Context* ctx, const ExistingClientRequest& msg, internal::TransactionID tx); + // Request that we create a new raft group from scratch, setting ourselves + // as the sole member and leader. This should be done to seed a new + // Raft, after which we should requst JoinRaft instead. + void HandleCreateNewRaftGroupRequest(context::Context* ctx, internal::TransactionID tx) EXCLUDES(raft_.mu); + // Creates a test account within the Raft DB. + error::Error AddTestAccount(context::Context* ctx, uint32_t i); + // Join an existing Raft group. 
+ void HandleJoinRaft(context::Context* ctx, const JoinRaftRequest& msg, internal::TransactionID tx) EXCLUDES(raft_.mu); + // Given a single seed peer, connect to it and get the existing configs. + void JoinRaftFromFirstPeer(context::Context* ctx) REQUIRES(raft_.mu); + // Replicate all data from our existing peer(s) until we've got a full set of data. + void RequestRaftReplication(context::Context* ctx) REQUIRES(raft_.mu); + // Now that we've got a full set of Raft data (logs+db), set up our local Raft objects. + void PromoteRaftToLoaded(context::Context* ctx) REQUIRES(raft_.mu); + // Request to become a (nonvoting) member of the Raft group we have data for. + void RaftRequestMembership(context::Context* ctx, internal::TransactionID tx) REQUIRES(raft_.mu); + // Refresh attestations for peer and client connections. + error::Error HandleRefreshAttestation(context::Context* ctx, bool rotate_key) EXCLUDES(raft_.mu); + // Get the current status of this replica to be returned to the host. + std::pair HandleGetEnclaveStatus(context::Context* ctx) const EXCLUDES(raft_.mu); + // Handle a host-requested delete of a backup ID. + error::Error HandleHostDatabaseRequest(context::Context* ctx, internal::TransactionID tx, const DatabaseRequest& req); + // Reconfigure the replica with new host-supplied configuration. + error::Error HandleReconfigure(context::Context* ctx, internal::TransactionID tx, const enclaveconfig::EnclaveConfig& req) EXCLUDES(raft_.mu); + // If we're the raft leader, give it up. + void HandleRelinquishLeadership(context::Context* ctx, internal::TransactionID tx) EXCLUDES(raft_.mu); + // Request that this replica be removed from the Raft group. + void HandleHostRequestedRaftRemoval(context::Context* ctx, internal::TransactionID tx) EXCLUDES(raft_.mu); + // Compute and return to the host a hash of the current DB. + error::Error HandleHostHashes(context::Context* ctx, internal::TransactionID tx) EXCLUDES(raft_.mu); + + // Handle the inevitable march of time. 
+ void HandleTimerTick(context::Context* ctx, const TimerTick& tick); + // Update our group-based concept of time. + void MaybeUpdateGroupTime(context::Context* ctx) EXCLUDES(raft_.mu); + // If we're in Raft with some other replicas but don't yet have peer connections + // to them, try to establish them. + void ConnectToRaftMembers(context::Context* ctx) REQUIRES(raft_.mu); + // Return either a nullptr, or a replica config (in scope [ctx]) that + // this instance believes should be the next config for this raft group. + raft::ReplicaGroup* NextReplicaGroup(context::Context* ctx) REQUIRES(raft_.mu); + + // Decode a new message proxied from a peer replica through our host. + error::Error HandlePeerMessage(context::Context* ctx, const UntrustedMessage& msg); + // Handle an EnclaveToEnclaveMessage decoded from the peer message + error::Error HandleE2E(context::Context* ctx, const peerid::PeerID& from, const e2e::EnclaveToEnclaveMessage& msg); + // Handle the case where we've just successfully established a connection to the peer `from` + void HandlePeerConnect(context::Context* ctx, const peerid::PeerID& from); + // Handle an enclave-to-enclave transaction requested by a remote peer client. + error::Error HandleE2ETransaction(context::Context* ctx, const peerid::PeerID& from, const e2e::TransactionRequest& msg); + // Handle a request to replicate our state (Raft DB and logs) to `from` + error::Error HandleReplicateStateRequest(context::Context* ctx, const peerid::PeerID& from, const e2e::TransactionRequest& req) EXCLUDES(raft_.mu); + // Send the next set of replicating state to `from`, in the form of a ReplicateStatePush E2E transaction. + void SendNextReplicationState(context::Context* ctx, std::shared_ptr push_state) REQUIRES(raft_.mu); + // Handle receipt of the next piece of state from a server that's replicating their state to us. 
+ error::Error HandleReplicateStatePush(context::Context* ctx, const e2e::ReplicateStatePush& push) EXCLUDES(raft_.mu); + // Handle applying replicated state to an as-yet-unfinished Raft database (in raft_.loading.db) + error::Error MaybeApplyLogToReplicatingDatabase(context::Context* ctx, const raft::LogEntry& entry) REQUIRES(raft_.mu); + // Handle a request to join our Raft group. + error::Error HandleRequestRaftMembership(context::Context* ctx, const peerid::PeerID& from, e2e::TransactionResponse* resp) EXCLUDES(raft_.mu); + // Handle a request to become a voting member of our Raft group. + error::Error HandleRequestRaftVoting(context::Context* ctx, const peerid::PeerID& from, e2e::TransactionResponse* resp) EXCLUDES(raft_.mu); + // Handle a request to write a client log into our Raft group. + error::Error HandleRaftWrite(context::Context* ctx, const std::string& data, e2e::TransactionResponse* resp) EXCLUDES(raft_.mu); + // Handle receipt of a new timestamp supplied by `from`. + void HandleNewTimestamp(context::Context* ctx, const peerid::PeerID& from, uint64_t unix_secs); + // Handle a request to remove the sender from Raft. + error::Error HandlePeerRequestedRaftRemoval(context::Context* ctx, const peerid::PeerID& from, internal::TransactionID tx) EXCLUDES(raft_.mu); + + //// Common or utility functions called by multiple handlers. + + // RaftStep handles sending any outstanding raft messages and applying + // any committed transactions. It should be called after any change to + // Raft state, including receiving a raft message, requesting a client + // log, etc. + void RaftStep(context::Context* ctx) REQUIRES(raft_.mu); + // Send any messages buffered by raft to our peers. + void RaftSendMessages(context::Context* ctx) REQUIRES(raft_.mu); + // See if any logs have been committed since last we looked, and apply them to our + // internal state if there are some. 
+ void RaftHandleCommittedLogs(context::Context* ctx) REQUIRES(raft_.mu); + // Handle a Raft log that changes group membership, which may either + // add us to a group or remove us from our group. + void HandleRaftMembershipChange( + context::Context* ctx, + raft::LogIdx idx, + raft::TermId term, + const raft::ReplicaGroup& membership_change) REQUIRES(raft_.mu); + // Attempt to apply the committed log entry to the db::DB. On success, + // return a db::DB::Response (owned by [ctx]). On failure, return + // nullptr. Regardless, [committed_entry] is considered to be successfully + // committed to the database after this call. + db::DB::Response* RaftApplyLogToDatabase( + context::Context* ctx, + raft::LogIdx idx, + const raft::LogEntry& committed_entry) REQUIRES(raft_.mu); + // HandleLogTransactionsForRaftLog handles any queued log + // transactions in outstanding_log_transactions_ associated with the given + // log entry. + void HandleLogTransactionsForRaftLog( + context::Context* ctx, + raft::LogIdx idx, + const raft::LogEntry& entry, + // response may be null in the case where we failed to parse it from the Raft log. + const db::DB::Response* response) REQUIRES(raft_.mu); + + // Send a local timestamp to remote peer `to`. + void SendTimestamp(context::Context* ctx, peerid::PeerID to, uint64_t unix_seconds); + // Send our local timestamp to all connected peers. 
+ void SendTimestampToAll(context::Context* ctx); + + static error::Error ValidateConfig(const enclaveconfig::EnclaveConfig& config); + static error::Error ValidateConfigChange(const enclaveconfig::EnclaveConfig& old_config, const enclaveconfig::EnclaveConfig& new_config); + + mutable util::mutex config_mu_; + enclaveconfig::EnclaveConfig enclave_config_ GUARDED_BY(config_mu_); + const enclaveconfig::RaftGroupConfig raft_config_template_; + + enclaveconfig::EnclaveConfig* enclave_config(context::Context* ctx) const EXCLUDES(config_mu_); + + std::unique_ptr peer_manager_; + std::unique_ptr client_manager_; + + internal::Raft raft_; + const enclaveconfig::DatabaseVersion db_version_; + const db::DB::Protocol* const db_protocol_; + groupclock::Clock clock_; + + // Handle timeouts. + timeout::Timeout timeout_; + + typedef std::function LogTransactionCallback; + // When we submit a transaction to the log, we get back the idx/term + // at which it should be committed. Later, we see that LogIdx go by, and + // if the term matches, we're in business and can execute the transaction. + // If the term does _not_ match, then this transaction was overridden or + // cancelled by a Raft election. + struct LogTransaction { + raft::TermId term; + LogTransactionCallback cb; + // If the expected_hash_chain is the empty string it is ignored. Otherwise + // if the hash_chain for this log index does not match the + // expected_hash_chain the transaction is aborted. + std::string expected_hash_chain; + }; + // This is a multimap because, if the leader changes, we could possibly + // have multiple transactions mapped to the same log index (with different + // terms). + util::mutex outstanding_log_transactions_mu_; + std::unordered_multimap outstanding_log_transactions_ GUARDED_BY(outstanding_log_transactions_mu_); + // Adds a callback to be run when the log at the given location has been committed. + // NOTE: when cb is called, raft_.mu will be locked already. 
+ void AddLogTransaction(context::Context* ctx, const raft::LogLocation& loc, LogTransactionCallback cb) EXCLUDES(outstanding_log_transactions_mu_); + error::Error RaftWriteLogTransaction(context::Context* ctx, const std::string& data, LogTransactionCallback cb) EXCLUDES(raft_.mu); + LogTransactionCallback ClientLogTransaction(context::Context* ctx, client::ClientID client_id, internal::TransactionID tx); + + // State for transactions that this enclave sends to other enclaves. + // Transactions are kept locally as a map of callbacks (of type + // E2ECallback). On receipt of a response, we look for the appropriate + // callback in the outstanding_e2e_transactions_ map and call it. + util::mutex e2e_txn_mu_ ACQUIRED_AFTER(raft_.mu); + internal::TransactionID e2e_txn_id_ GUARDED_BY(e2e_txn_mu_); + typedef std::function E2ECallback; + struct E2ECall { + E2ECallback callback; + timeout::Cancel timeout_cancel; + }; + std::unordered_map outstanding_e2e_transactions_ GUARDED_BY(e2e_txn_mu_); + // Send an Enclave-to-enclave transaction. + void SendE2ETransaction( + context::Context* ctx, + const peerid::PeerID& to, + const e2e::TransactionRequest& req, + bool with_timeout, // If false, allow to run forever. 
+ E2ECallback callback) EXCLUDES(e2e_txn_mu_); + error::Error SendE2EError(context::Context* ctx, const peerid::PeerID& from, internal::TransactionID id, error::Error err); +}; + +} // namespace svr2::core + +#endif // __SVR2_CORE_CORE_H__ diff --git a/enclave/core/coretest/replicagroup.cc b/enclave/core/coretest/replicagroup.cc new file mode 100644 index 0000000..7ee02d7 --- /dev/null +++ b/enclave/core/coretest/replicagroup.cc @@ -0,0 +1,145 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "replicagroup.h" + +#include + +namespace svr2::core::test { + +bool ReplicaGroup::IsQuiet() const { + for (const auto& [peer_id, core] : peers_by_id_) { + if (core->active() && core->input_messages().size() > 0) return false; + } + return true; +} + +error::Error ReplicaGroup::SendMessage(peerid::PeerID to, PeerMessage msg) { + peerid::PeerID from; + from.FromString(msg.peer_id()); + PartitionID to_partition = partition_[to]; + PartitionID from_partition = partition_[from]; + + if (to_partition == from_partition) { + LOG(VERBOSE) << "#####################################################"; + LOG(VERBOSE) << "# peer message to " << to << " from " << from; + RETURN_IF_ERROR(peers_by_id_[to]->AddPeerMessage(std::move(msg))); + } else { + LOG(VERBOSE) << "#---------------------------------------------------#"; + LOG(VERBOSE) << "# BLOCKED peer message to " << to << " from " << from; + blocked_peer_messages_[to].emplace_back(std::move(msg)); + } + return error::OK; +} + +error::Error ReplicaGroup::PassMessagesUntilQuiet(PartitionID pid) { + error::Error err = error::OK; + while (!IsQuiet()) { + for (auto& core : peers_) { + if (pid == FULL_GROUP_PARTITION_ID || + partition_.find(core->ID())->second == pid) { + RETURN_IF_ERROR(core->ProcessIncomingMessage()); + RETURN_IF_ERROR(core->ForwardOutgoingMessages()); + } + } + } + return err; +} + +error::Error ReplicaGroup::ProcessAllH2EResponses(PartitionID pid) { + for (auto& core : 
peers_) { + if (pid == FULL_GROUP_PARTITION_ID || + partition_.find(core->ID())->second == pid) { + RETURN_IF_ERROR(core->ProcessAllH2EResponses()); + } + } + return error::OK; +} + +error::Error ReplicaGroup::TickAllTimers(PartitionID pid) { + for (auto& [peer_id, core] : peers_by_id_) { + if (pid == FULL_GROUP_PARTITION_ID || + partition_.find(peer_id)->second == pid) { + RETURN_IF_ERROR(core->TimerTick()); + RETURN_IF_ERROR(core->ProcessIncomingMessage()); + } + } + + for (auto& [peer_id, core] : peers_by_id_) { + if (pid == FULL_GROUP_PARTITION_ID || + partition_.find(peer_id)->second == pid) { + RETURN_IF_ERROR(core->ForwardOutgoingMessages()); + } + } + return error::OK; +} + +void ReplicaGroup::TickTock(bool ignore_h2e_errors) { + TickTock(FULL_GROUP_PARTITION_ID, ignore_h2e_errors); +} + +void ReplicaGroup::TickTock(PartitionID pid, bool ignore_h2e_errors) { + ASSERT_EQ(error::OK, TickAllTimers(pid)); + ASSERT_EQ(error::OK, PassMessagesUntilQuiet()); + auto err = ProcessAllH2EResponses(); + if (!ignore_h2e_errors) ASSERT_EQ(error::OK, err); +} + +void ReplicaGroup::add_peer() { + peers_.emplace_back(std::make_unique(*this)); + auto peer = peers_.rbegin()->get(); + peers_by_id_[peer->ID()] = peer; +} + +void ReplicaGroup::Init(enclaveconfig::InitConfig cfg, + size_t initial_voting, + size_t initial_nonvoting, size_t initial_nonmember) { + init_config_ = cfg; + enclave_config_ = cfg.enclave_config(); + size_t num_cores = initial_voting + initial_nonvoting + initial_nonmember; + LOG(INFO) << "ADDING " << num_cores << " PEERS"; + for (size_t i = 0; i < num_cores; ++i) { + add_peer(); + } + + LOG(INFO) << "CREATING RAFT"; + ASSERT_EQ(error::OK, peers_[0]->CreateNewRaftGroup()); + ASSERT_EQ(error::OK, PassMessagesUntilQuiet()); + for (size_t i = 1; i < initial_voting + initial_nonvoting; ++i) { + LOG(INFO) << "JOINING " << i << " of " << (initial_nonvoting + initial_voting); + // request to join raft from the previous peer (so not always the leader) + 
ASSERT_EQ(error::OK, peers_[i]->JoinRaft(peers_[i - 1]->ID())); + ASSERT_EQ(error::OK, PassMessagesUntilQuiet()); + CHECK(peers_[i]->serving()); + } + + for (size_t i = 1; i < initial_voting; ++i) { + LOG(INFO) << "VOTING " << i << " of " << initial_voting; + ASSERT_EQ(error::OK, peers_[i]->RequestVoting()); + ASSERT_EQ(error::OK, PassMessagesUntilQuiet()); + } + + std::vector partition_members; + for (const auto& peer : peers_) { + auto peer_id = peer->ID(); + partition_[peer_id] = 1; + partition_members.emplace_back(std::move(peer_id)); + } + partition_members_.emplace(std::make_pair(1, partition_members)); + + ASSERT_EQ(error::OK, PassMessagesUntilQuiet()); +} +void ReplicaGroup::ForwardBlockedMessages() { // hands every blocked message to its recipient's AddPeerMessage + for (auto&& [peer_id, msgs] : blocked_peer_messages_) { + for (auto&& msg : msgs) { + peerid::PeerID from; + from.FromString(msg.peer_id()); + LOG(VERBOSE) << "#******************************************#"; + LOG(VERBOSE) << "# Forwarding blocked peer message to " << peer_id + << " from " << from; + ASSERT_EQ(error::OK, peers_by_id_[peer_id]->AddPeerMessage(std::move(msg))); // do not silently drop a delivery error + } + } +} + +}; // namespace svr2::core::test diff --git a/enclave/core/coretest/replicagroup.h b/enclave/core/coretest/replicagroup.h new file mode 100644 index 0000000..ec20286 --- /dev/null +++ b/enclave/core/coretest/replicagroup.h @@ -0,0 +1,249 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_CORE_CORETEST_REPLICAGROUP_H__ +#define __SVR2_CORE_CORETEST_REPLICAGROUP_H__ + +#include +#include +#include + +#include "core/core.h" +#include "peerid/peerid.h" +#include "testingcore.h" +#include "util/macros.h" + +namespace svr2::core::test { +using PartitionID = uint32_t; +using TestingCoreMap = std::map; +using PartitionMap = std::map; +using ReversePartitionMap = std::map>; + +template +std::pair LargestPartition( + const std::map &partition) { + std::map counts; + + size_t max_count{0}; + PartitionID largest_partition{0}; + for (const auto &[key, 
val] : partition) { + counts[val]++; + if (counts[val] > max_count) { + max_count = counts[val]; + largest_partition = val; + } + } + return std::make_pair(largest_partition, max_count); +} + +class ReplicaGroup { + // This PartitionID represents the full replica group and is used + // internally to override an existing partition. + static const PartitionID FULL_GROUP_PARTITION_ID = UINT32_MAX; + + public: + ReplicaGroup() {} + // This is not copyable because `peers_` is not copyable + DELETE_COPY_AND_ASSIGN(ReplicaGroup); + + const TestingCore *get_core(size_t i) const { + CHECK(i < peers_.size()); + return peers_[i].get(); + } + + TestingCore *get_core(size_t i) { + CHECK(i < peers_.size()); + return peers_[i].get(); + } + + TestingCore *get_leader_core() { return get_core(GroupLeaderIndex()); } + const TestingCore *get_leader_core() const { + return get_core(GroupLeaderIndex()); + } + TestingCore *get_voting_nonleader_core() { + auto peer = std::find_if(peers_.cbegin(), peers_.cend(), [](const auto &p) { + return p->voting() && !p->leader(); + }); + return peer != peers_.cend() ? peer->get() : nullptr; + } + + size_t partition_size(size_t i) const { + auto id = peers_[i]->ID(); + PartitionID part_id = partition_.find(id)->second; + return partition_members_.find(part_id)->second.size(); + } + + enclaveconfig::EnclaveConfig get_enclave_config() const { + return enclave_config_; + } + enclaveconfig::InitConfig get_init_config() const { + return init_config_; + } + + size_t num_voting() const { return get_leader_core()->num_voting(); } + + size_t num_serving() const { return get_leader_core()->num_serving(); } + + /*** + * Creates and initializes TestingCores with given configuration. 
The first + * `initial_voting` items in the returned vector will be accepted voting + * members, the next `initial_nonvoting` will be up-to-date non-voting + * members, and the rest will be connected non-members + */ + void Init(enclaveconfig::InitConfig cfg, + size_t initial_voting, + size_t initial_nonvoting, size_t initial_nonmember); + /** + * @brief Check whether any active replicas still have messages to process + * + * @return true No active replica has a message to process + * @return false Some active replica has a message to process + */ + bool IsQuiet() const; + + /** + * @brief Get the ID of the group leader if a quorum with a leader exists. + * + * @return peerid::PeerID A valid ID if a quorum is possible and a leader + * exists + */ + peerid::PeerID GroupLeader() const { + auto index = GroupLeaderIndex(); + return index < peers_.size() ? peers_[index]->ID() : peerid::PeerID(); + } + + /** + * @brief Get the index of the group leader if a quorum with a leader exists. + * + * @return size_t peers_.size() if no leader is found, index of the leader + * otherwise. + */ + size_t GroupLeaderIndex() const { + auto [largest_partition, partition_size] = LargestPartition(partition_); + auto found = std::find_if( + peers_.cbegin(), peers_.cend(), + [this, largest_partition = largest_partition](const auto &p) { + return p->leader() && p->active() && + partition_.find(p->ID())->second == largest_partition; + }); + return found - peers_.cbegin(); + } + /** + * @brief Find ID of group leader in a peer's partition + * + * @param peer_id ID of peer looking for reachable leader + * @return peerid::PeerID ID of a replica that (1) believes it is leader and + * (2) is in same partition as peer_id OR, if not found, returns invalid + * PeerID. 
+ */ + peerid::PeerID GroupLeaderInPartition(peerid::PeerID peer_id) const { + auto found = std::find_if(peers_.cbegin(), peers_.cend(), + [this, peer_id](const auto &p) { + return p->leader() && p->active() && + partition_.find(p->ID())->second == + partition_.find(peer_id)->second; + }); + return found == peers_.cend() ? peerid::PeerID() : (*found)->ID(); + } + /** + * @brief Find index of group leader in a peer's partition + * + * @param peer_id ID of peer looking for reachable leader + * @return size_t of a replica that (1) believes it is leader and (2) is + * in same partition as peer_id OR, if not found, returns peers_.size(). + */ + size_t GroupLeaderIndexInPartition(peerid::PeerID peer_id) const { + auto found = std::find_if(peers_.cbegin(), peers_.cend(), + [this, peer_id](const auto &p) { + return p->leader() && p->active() && + partition_.find(p->ID())->second == + partition_.find(peer_id)->second; + }); + return found - peers_.cbegin(); + } + + /** + * @brief Send a message (through the `replica_group_` fabric) to a peer. + * + * @param to Recipient ID + * @param msg + * @return error::Error Error from `TestingCore::AddPeerMessage` or + * `error::OK`. + */ + error::Error SendMessage(peerid::PeerID to, PeerMessage msg); + /** + * @brief All peers in a partition process incoming messages then forward + * resulting outgoing messages until there are no more incoming messages to + * process + * + * @param pid Optional partition ID. If not provided then partitioning is + * ignored and it applies to full group + * @return error::Error + */ + error::Error PassMessagesUntilQuiet( + PartitionID pid = FULL_GROUP_PARTITION_ID); + /** + * @brief All peers in a partition process all responses from enclaves to + * hosts. + * + * @param pid Optional partition ID. 
If not provided then partitioning is + * ignored and it applies to full group + * @return error::Error returns any error from a HostToEnclaveResponse + */ + error::Error ProcessAllH2EResponses( + PartitionID pid = FULL_GROUP_PARTITION_ID); + /** + * @brief All peers in a partition get a timer tick, process it, then forward + * any outgoing messages + * + * @param pid Optional partition ID. Defaults to the full group + * @return error::Error + */ + error::Error TickAllTimers(PartitionID pid = FULL_GROUP_PARTITION_ID); + /** + * @brief Tick all timers, pass messages until quiet, and then optionally + * check to see if any errors came back in the HostToEnclaveResponses + * + * @param ignore_h2e_errors If true, errors from the H2E responses are not asserted on + */ + void TickTock(bool ignore_h2e_errors); + void TickTock(PartitionID pid, bool ignore_h2e_errors); + + void CreatePartition(std::map partition) { + partition_.clear(); + partition_members_.clear(); + + // map the array indices to PeerIDs + for (auto [idx, partition_id] : partition) { + auto peer_id = get_core(idx)->ID(); + partition_[peer_id] = partition_id; + partition_members_[partition_id].emplace_back(peer_id); + } + } + + void ClearPartition() { + partition_.clear(); + partition_members_.clear(); + for (const auto &peer : peers_) { + partition_[peer->ID()] = 1; + partition_members_[1].emplace_back(peer->ID()); + } + } + + void ForwardBlockedMessages(); + void ClearBlockedMessages() { blocked_peer_messages_.clear(); } + + private: + void add_peer(); + enclaveconfig::EnclaveConfig enclave_config_; + enclaveconfig::InitConfig init_config_; + std::vector> peers_; + + TestingCoreMap peers_by_id_; + PartitionMap partition_; + ReversePartitionMap partition_members_; + std::map> blocked_peer_messages_; +}; + +}; // namespace svr2::core::test +#endif // __SVR2_CORE_CORETEST_REPLICAGROUP_H__ diff --git a/enclave/core/coretest/testingclient.cc b/enclave/core/coretest/testingclient.cc new file mode 100644 index 0000000..f14dcf2 --- /dev/null +++ b/enclave/core/coretest/testingclient.cc @@ -0,0 +1,184 @@ +// Copyright 2023 
Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "testingclient.h" + +#include +#include + +#include "testingcore.h" +#include "util/bytes.h" + +#define NOISE_OK(x) \ + do { \ + int out = (x); \ + if (out != NOISE_ERROR_NONE) { \ + char buf[64]; \ + noise_strerror(out, buf, sizeof(buf)); \ + ASSERT_EQ(out, NOISE_ERROR_NONE) << "Noise error: " << buf; \ + } \ + } while (0) + +namespace svr2::core::test { + +using svr2::util::ByteArrayToString; + +TestingClient::TestingClient(TestingCore& core, const std::string& authenticated_id) + : core_(core), + client_authenticated_id_(authenticated_id), + hs_(noise::WrapHandshakeState(nullptr)), + tx_(noise::WrapCipherState(nullptr)), + rx_(noise::WrapCipherState(nullptr)) {} + +void TestingClient::RequestHandshake() { + state_ = State::HANDSHAKING; + ASSERT_EQ(error::OK, core_.NewClientRequest(this, client_authenticated_id_)); + NoiseHandshakeState* hsp; + NOISE_OK(noise_handshakestate_new_by_id(&hsp, &client::client_protocol, + NOISE_ROLE_INITIATOR)); + hs_ = noise::WrapHandshakeState(hsp); +} + +void TestingClient::RequestBackup(SecretData data, PIN pin, uint32_t tries) { + LOG(INFO) << "sending backup request"; + + client::Request req; + auto b = req.mutable_backup(); + b->set_data(ByteArrayToString(data)); + b->set_pin(ByteArrayToString(pin)); + b->set_max_tries(tries); + + // serialize and encrypt + std::string req_str; + ASSERT_TRUE(req.SerializeToString(&req_str)); + auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str); + ASSERT_EQ(error::OK, encrypt_err); + ASSERT_EQ(error::OK, + core_.ExistingClientRequest(this, client_id_, ciphertext)); + state_ = State::AWAITING_BACKUP; +} + +void TestingClient::RequestExpose(SecretData data) { + LOG(INFO) << "sending expose request"; + + client::Request req; + auto b = req.mutable_expose(); + b->set_data(ByteArrayToString(data)); + + // serialize and encrypt + std::string req_str; + ASSERT_TRUE(req.SerializeToString(&req_str)); + auto 
[ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str); + ASSERT_EQ(error::OK, encrypt_err); + ASSERT_EQ(error::OK, + core_.ExistingClientRequest(this, client_id_, ciphertext)); + state_ = State::AWAITING_AVAILABLE; +} + +void TestingClient::RequestRestore(PIN pin) { + LOG(INFO) << "sending restore request"; + + client::Request req; + auto b = req.mutable_restore(); + b->set_pin(ByteArrayToString(pin)); + + // serialize and encrypt + std::string req_str; + ASSERT_TRUE(req.SerializeToString(&req_str)); + auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str); + ASSERT_EQ(error::OK, encrypt_err); + ASSERT_EQ(error::OK, + core_.ExistingClientRequest(this, client_id_, ciphertext)); + state_ = State::AWAITING_RESTORE; +} + +void TestingClient::HandleNewClientReply(NewClientReply ncr) { + client_id_ = ncr.client_id(); + ASSERT_GT(client_id_, 0ul); + LOG(VERBOSE) << "new client " << client_id_; + + auto hsp = hs_.get(); + auto hs_msg = ncr.handshake_start(); + NOISE_OK(noise_dhstate_set_public_key( + noise_handshakestate_get_remote_public_key_dh(hsp), + noise::StrU8Ptr(hs_msg.test_only_pubkey()), + hs_msg.test_only_pubkey().size())); + + NOISE_OK(noise_handshakestate_start(hsp)); + ASSERT_EQ(NOISE_ACTION_WRITE_MESSAGE, noise_handshakestate_get_action(hsp)); + + // Now pass a message to complete the handshake + std::string data; + data.resize(noise::HANDSHAKE_INIT_SIZE, '\0'); + NoiseBuffer write_buf = noise::BufferOutputFromString(&data); + NOISE_OK(noise_handshakestate_write_message(hsp, &write_buf, nullptr)); + data.resize(write_buf.size, '\0'); + + ASSERT_EQ(error::OK, core_.ExistingClientRequest(this, client_id_, data)); // checked like the other Request* helpers + // now we wait for the existing client reply to finish the handshake +} + +void TestingClient::FinishHandshake(ExistingClientReply ecr) { + LOG(VERBOSE) << "finish handshake client: " << client_id_; + auto hsp = hs_.get(); + NoiseCipherState* txp; + NoiseCipherState* rxp; + + ASSERT_EQ(NOISE_ACTION_READ_MESSAGE, noise_handshakestate_get_action(hsp)); 
+ NoiseBuffer read_buf = noise::BufferInputFromString(ecr.mutable_data()); + NOISE_OK(noise_handshakestate_read_message(hsp, &read_buf, nullptr)); + ASSERT_EQ(NOISE_ACTION_SPLIT, noise_handshakestate_get_action(hsp)); + NOISE_OK(noise_handshakestate_split(hsp, &txp, &rxp)); + + tx_ = noise::WrapCipherState(txp); + rx_ = noise::WrapCipherState(rxp); + state_ = State::READY; +} + +void TestingClient::DecryptClientReply(ExistingClientReply ecr, + client::Response* rsp) { + auto [plaintext, decrypt_err] = noise::Decrypt(rx_.get(), ecr.data()); + ASSERT_EQ(error::OK, decrypt_err); + + ASSERT_TRUE(rsp->ParseFromString(plaintext)); +} + +void TestingClient::HandleBackupResponse(ExistingClientReply ecr) { + client::Response response; + DecryptClientReply(ecr, &response); + ASSERT_EQ(response.inner_case(), client::Response::kBackup); + backup_response_ = response.backup(); + state_ = State::BACKUP_READY; +} +void TestingClient::HandleExposeResponse(ExistingClientReply ecr) { + client::Response response; + DecryptClientReply(ecr, &response); + ASSERT_EQ(response.inner_case(), client::Response::kExpose); + expose_response_ = response.expose(); + state_ = State::AVAILABLE_READY; +} +void TestingClient::HandleRestoreResponse(ExistingClientReply ecr) { + client::Response response; + DecryptClientReply(ecr, &response); + ASSERT_EQ(response.inner_case(), client::Response::kRestore); + restore_response_ = response.restore(); + state_ = State::RESTORE_READY; +} + +void TestingClient::HandleExistingClientReply(ExistingClientReply ecr) { + LOG(VERBOSE) << "state_: " + << static_cast::type>(state_); + switch (state_) { + case State::HANDSHAKING: + return FinishHandshake(ecr); + case State::AWAITING_BACKUP: + return HandleBackupResponse(ecr); + case State::AWAITING_RESTORE: + return HandleRestoreResponse(ecr); + case State::AWAITING_AVAILABLE: + return HandleExposeResponse(ecr); + default: + CHECK(false); + } +} +}; // namespace svr2::core::test diff --git 
a/enclave/core/coretest/testingclient.h b/enclave/core/coretest/testingclient.h new file mode 100644 index 0000000..a7a2294 --- /dev/null +++ b/enclave/core/coretest/testingclient.h @@ -0,0 +1,81 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_CORE_CORETEST_CLIENT_H__ +#define __SVR2_CORE_CORETEST_CLIENT_H__ + +#include +#include + +#include "db/db.h" // for BACKUP_ID_SIZE +#include "noise/noise.h" +#include "proto/client.pb.h" + +namespace svr2::core::test { +class TestingCore; + +class TestingClient { + public: + using PIN = std::array; + using SecretData = std::array; + + client::BackupResponse* get_backup_response() { + return state_ == State::BACKUP_READY ? &backup_response_ : nullptr; + } + + client::RestoreResponse* get_restore_response() { + return state_ == State::RESTORE_READY ? &restore_response_ : nullptr; + } + + client::ExposeResponse* get_expose_response() { + return state_ == State::AVAILABLE_READY ? &expose_response_ : nullptr; + } + + // These functions return void so that we can use gtest assertions inside + // them. 
(gtest assertions that generate a fatal failure can only be used with + // void-returning functions: + // https://chromium.googlesource.com/external/github.com/google/googletest/+/HEAD/docs/advanced.md#assertion-placement) + void RequestHandshake(); + void RequestBackup(SecretData data, PIN pin, uint32_t tries); + void RequestExpose(SecretData data); + void RequestRestore(PIN pin); + + void HandleNewClientReply(NewClientReply ncr); + void HandleExistingClientReply(ExistingClientReply ecr); + + TestingClient(TestingCore& core, const std::string& authenticated_id); + + private: + enum class State { + NO_HANDSHAKE, + HANDSHAKING, + READY, + AWAITING_BACKUP, + AWAITING_RESTORE, + AWAITING_AVAILABLE, + BACKUP_READY, + RESTORE_READY, + AVAILABLE_READY + }; + void FinishHandshake(ExistingClientReply ecr); + void HandleBackupResponse(ExistingClientReply ecr); + void HandleExposeResponse(ExistingClientReply ecr); + void HandleRestoreResponse(ExistingClientReply ecr); + void DecryptClientReply(ExistingClientReply ecr, client::Response* rsp); + + TestingCore& core_; + std::string client_authenticated_id_; + uint64_t client_id_{0}; + State state_{State::NO_HANDSHAKE}; + noise::HandshakeState hs_; + noise::CipherState tx_; + noise::CipherState rx_; + + client::BackupResponse backup_response_; + client::RestoreResponse restore_response_; + client::ExposeResponse expose_response_; +}; + +}; // namespace svr2::core::test + +#endif // __SVR2_CORE_CORETEST_CLIENT_H__ diff --git a/enclave/core/coretest/testingcore.cc b/enclave/core/coretest/testingcore.cc new file mode 100644 index 0000000..414fd1b --- /dev/null +++ b/enclave/core/coretest/testingcore.cc @@ -0,0 +1,290 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "testingcore.h" + +#include + +#include "util/bytes.h" + +#include "replicagroup.h" +#include "testingclient.h" + +namespace svr2::core::test { + +TestingCore::TestingCore(ReplicaGroup& replica_group) + : 
replica_group_(replica_group) { + context::Context ctx; + enclaveconfig::InitConfig cfg = replica_group.get_init_config(); + cfg.set_initial_timestamp_unix_secs(timer_secs_); + auto [core, err] = Core::Create(&ctx, cfg); + if (err != error::OK) { + LOG(ERROR) << "Could not create core: " << err; + CHECK(false); + } + core_ = std::move(core); +} + +error::Error TestingCore::ProcessIncomingMessage() { + error::Error result = error::OK; + if (!active() || input_messages_.empty()) { + return result; + } + + // send the commands and other messages to the enclave + LOG(VERBOSE) << "Core " << ID() << " processing first of " + << input_messages_.size() << " messages"; + context::Context ctx; + + // take the input message + auto msg = std::move(input_messages_.front()); + input_messages_.pop_front(); + auto err = core_->Receive(&ctx, msg); + if (err != error::OK) { + // clear the messages and return error + env::test::SentMessages(); + return err; + } + + // get the responses + auto response_msgs = env::test::SentMessages(); + + // process according to type + peerid::PeerID to; + PeerMessage peer_msg; + for (auto& response : response_msgs) { + switch (response.inner_case()) { + case EnclaveMessage::kPeerMessage: + peer_msg = std::move(*response.mutable_peer_message()); + + // read who this message is *to* + to.FromString(peer_msg.peer_id()); + + // Now reset the peer_id in the message to our ID so the + // recipient knows who it is *from* + ID().ToString(peer_msg.mutable_peer_id()); + peer_messages_out_[to].emplace_back(std::move(peer_msg)); + break; + case EnclaveMessage::kH2EResponse: + h2e_responses_out_.emplace_back(response.h2e_response()); + break; + default: + CHECK(false); + } + } + return error::OK; +} + +error::Error TestingCore::ProcessAllIncomingMessages() { + while (!input_messages_.empty()) { + RETURN_IF_ERROR(ProcessIncomingMessage()); + } + return error::OK; +} + +error::Error TestingCore::ProcessNextH2EResponse() { + auto h2e_response = 
std::move(h2e_responses_out_.front()); + h2e_responses_out_.pop_front(); + auto request_id = h2e_response.request_id(); + auto cl = open_client_requests_[request_id]; + switch (h2e_response.inner_case()) { + case HostToEnclaveResponse::kStatus: + if (error::OK != h2e_response.status()) { + LOG(DEBUG) << ID() << " response for request " << request_id << " error: " << h2e_response.status(); + return h2e_response.status(); + } + break; + case HostToEnclaveResponse::kNewClientReply: + cl->HandleNewClientReply(h2e_response.new_client_reply()); + break; + case HostToEnclaveResponse::kExistingClientReply: + cl->HandleExistingClientReply(h2e_response.existing_client_reply()); + break; + case HostToEnclaveResponse::kGetEnclaveStatusReply: + break; + default: + CHECK(false); + } + return error::OK; +} + +error::Error TestingCore::ProcessAllH2EResponses() { + while (!h2e_responses_out_.empty()) { + RETURN_IF_ERROR(ProcessNextH2EResponse()); + } + return error::OK; +} + +error::Error TestingCore::AddPeerMessage(PeerMessage&& peer_message) { + if (state_ == State::ACTIVE || state_ == State::PAUSED_SAVE_MSGS) { + peerid::PeerID other_id; + other_id.FromString(peer_message.peer_id()); + LOG(VERBOSE) << " core " << ID() << " receiving message from " << other_id; + ::svr2::UntrustedMessage req; + *req.mutable_peer_message() = std::move(peer_message); + input_messages_.emplace_back(std::move(req)); + } + return error::OK; +} + +error::Error TestingCore::ForwardOutgoingMessages() { + for (auto& [to, msgs] : peer_messages_out_) { + for (auto& msg : msgs) { + RETURN_IF_ERROR(replica_group_.SendMessage(to, msg)); + } + } + peer_messages_out_.clear(); + return error::OK; +} + +error::Error TestingCore::ResetPeer(peerid::PeerID peer_id) { + LOG(VERBOSE) << "resetpeerreq " << core_->ID() << " -> " << peer_id; + UntrustedMessage msg; + auto reset_req = msg.mutable_reset_peer(); + peer_id.ToString(reset_req->mutable_peer_id()); + input_messages_.emplace_back(std::move(msg)); + return 
error::OK; +} + +error::Error TestingCore::PingPeer(peerid::PeerID peer_id) { + LOG(VERBOSE) << "pingreq " << core_->ID() << " -> " << peer_id; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(next_request_id()); + peer_id.ToString(host->mutable_ping_peer()->mutable_peer_id()); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +error::Error TestingCore::GetEnclaveStatus() { + LOG(VERBOSE) << "getenclavestatus " << core_->ID(); + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(next_request_id()); + + host->set_get_enclave_status(true); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +error::Error TestingCore::TimerTick() { + ++timer_secs_; + LOG(VERBOSE) << "timertick " << core_->ID() << " secs: " << timer_secs_; + UntrustedMessage msg; + msg.mutable_timer_tick()->set_new_timestamp_unix_secs(timer_secs_); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +error::Error TestingCore::CreateNewRaftGroup() { + LOG(VERBOSE) << "createnewraftgroup " << core_->ID(); + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(next_request_id()); + host->set_create_new_raft_group(true); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +error::Error TestingCore::JoinRaft(peerid::PeerID peer_id) { + if (!peer_id.Valid()) { + return error::Peers_InvalidID; + } + LOG(VERBOSE) << "joinraftreq " << core_->ID() << " -> " << peer_id; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(next_request_id()); + auto req = host->mutable_join_raft(); + peer_id.ToString(req->mutable_peer_id()); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} +error::Error TestingCore::RequestVoting() { + LOG(VERBOSE) << "requestvoting " << core_->ID(); + + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + 
host->set_request_id(next_request_id()); + host->set_request_voting(true); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +error::Error TestingCore::Reconfigure(const enclaveconfig::EnclaveConfig& config) { + LOG(VERBOSE) << "reconfigure " << core_->ID(); + config_ = config; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(next_request_id()); + host->mutable_reconfigure()->MergeFrom(config); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +error::Error TestingCore::RaftRemoval() { + LOG(VERBOSE) << "raft_removal " << core_->ID(); + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(next_request_id()); + host->set_request_removal(true); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +error::Error TestingCore::DeleteBackup(const std::string& client_authenticated_id) { + LOG(VERBOSE) << "deletebackup " << core_->ID(); + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(next_request_id()); + client::Request delete_; + delete_.mutable_delete_(); + CHECK(delete_.SerializeToString(host->mutable_database_request()->mutable_request())); + host->mutable_database_request()->set_authenticated_id(client_authenticated_id); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +error::Error TestingCore::NewClientRequest( + TestingClient* client, std::string client_authenticated_id) { + LOG(VERBOSE) << "newclient " << core_->ID(); + + UntrustedMessage msg; + auto h2e_req = msg.mutable_h2e_request(); + auto new_client_req = h2e_req->mutable_new_client(); + auto request_id = next_request_id(); + open_client_requests_[request_id] = client; + h2e_req->set_request_id(request_id); + + new_client_req->set_client_authenticated_id(client_authenticated_id); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +// Backup or Restore +error::Error 
TestingCore::ExistingClientRequest(TestingClient* client, + uint64_t client_id, + std::string data) { + LOG(VERBOSE) << "existingclient " << core_->ID(); + + UntrustedMessage msg; + auto h2e_req = msg.mutable_h2e_request(); + auto existing_client_req = h2e_req->mutable_existing_client(); + auto request_id = next_request_id(); + open_client_requests_[request_id] = client; + h2e_req->set_request_id(request_id); + + existing_client_req->set_client_id(client_id); + existing_client_req->set_data(data); + input_messages_.emplace_back(std::move(msg)); + return error::OK; +} + +EnclaveReplicaStatus TestingCore::TakeExpectedEnclaveStatusReply() { + auto& h2e_response = h2e_responses_out_[0]; + EXPECT_EQ(h2e_response.inner_case(), HostToEnclaveResponse::kGetEnclaveStatusReply); + auto result = std::move(h2e_response.get_enclave_status_reply()); + h2e_responses_out_.pop_front(); + return result; +} + +}; // namespace svr2::core::test diff --git a/enclave/core/coretest/testingcore.h b/enclave/core/coretest/testingcore.h new file mode 100644 index 0000000..90ba2b0 --- /dev/null +++ b/enclave/core/coretest/testingcore.h @@ -0,0 +1,124 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + + +#ifndef __SVR2_CORE_CORETEST_TESTINGCORE_H__ +#define __SVR2_CORE_CORETEST_TESTINGCORE_H__ + +#include +#include +#include +#include +#include + +#include "core/core.h" +#include "env/test/test.h" +#include "proto/enclaveconfig.pb.h" +#include "proto/error.pb.h" +#include "proto/msgs.pb.h" +#include "util/log.h" +#include "proto/client.pb.h" + +namespace svr2::core::test { +class TestingCore; +class ReplicaGroup; +class TestingClient; +using RequestID = uint64_t; +using TestingCoreMap = std::map; +using PeerMessageMap = std::map>; +using OpenClientRequests = std::map; + +/* +This class wraps the basic actions of a `Core` and plays much +of the role the host plays in a real deployment - wrapping requests, +forwarding messages to peers and clients, etc. 
+*/ +class TestingCore { + enum class State { ACTIVE, PAUSED_SAVE_MSGS, PAUSED_DROP_MSGS, STOPPED }; + + public: + TestingCore(ReplicaGroup& replica_group); + + error::Error Init() { return error::OK; } + + uint64_t next_request_id() { return ++(next_request_id_); } + peerid::PeerID ID() const { return core_->ID(); } + const std::map>& peer_messages_out() + const { + return peer_messages_out_; + } + + const std::deque& host_to_enclave_responses() const { + return h2e_responses_out_; + } + + std::deque take_host_to_enclave_responses() { + return std::move(h2e_responses_out_); + } + const std::deque& input_messages() const { + return input_messages_; + } + + bool leader() const { return core_->leader() && active(); } + bool serving() const { return core_->serving() && active(); } + bool voting() const { return core_->voting() && active(); } + bool active() const { return state_ == State::ACTIVE; } + size_t num_voting() const { return core_->num_voting(); } + size_t num_serving() const { return core_->num_members(); } + std::set all_replicas() const { return core_->all_replicas(); } + + void Stop() { state_ = State::STOPPED; } + void Pause(bool drop_msgs) { + state_ = drop_msgs ? 
State::PAUSED_DROP_MSGS : State::PAUSED_SAVE_MSGS; + } + void Reactivate() { state_ = State::ACTIVE; } + + error::Error ProcessIncomingMessage(); + error::Error ProcessAllIncomingMessages(); + error::Error ForwardOutgoingMessages(); + error::Error ProcessNextH2EResponse(); + error::Error ProcessAllH2EResponses(); + + // Host to Enclave commands + error::Error ResetPeer(peerid::PeerID peer_id); + error::Error PingPeer(peerid::PeerID peer_id); + error::Error GetEnclaveStatus(); + error::Error TimerTick(); + error::Error CreateNewRaftGroup(); + error::Error JoinRaft(peerid::PeerID peer_id); + error::Error RequestVoting(); + error::Error Reconfigure(const enclaveconfig::EnclaveConfig& config); + error::Error DeleteBackup(const std::string& client_authenticated_id); + error::Error RaftRemoval(); + + // Peer communication + error::Error AddPeerMessage(PeerMessage&& peer_message); + + // Client communication + // handshake + error::Error NewClientRequest(TestingClient* client, + std::string client_authenticated_id); + + // Backup or Restore + error::Error ExistingClientRequest(TestingClient* client, uint64_t client_id, + std::string data); + + EnclaveReplicaStatus TakeExpectedEnclaveStatusReply(); + private: + std::unique_ptr core_; + ReplicaGroup& replica_group_; + enclaveconfig::EnclaveConfig config_; + + std::deque input_messages_; + std::deque h2e_responses_out_; + PeerMessageMap peer_messages_out_; + OpenClientRequests open_client_requests_; + + uint64_t next_request_id_{0}; + uint64_t timer_secs_{1}; + State state_{State::ACTIVE}; +}; + +}; // namespace svr2::core::test + +#endif // __SVR2_CORE_CORETEST_TESTINGCORE_H__ diff --git a/enclave/core/internal.h b/enclave/core/internal.h new file mode 100644 index 0000000..3c98e9f --- /dev/null +++ b/enclave/core/internal.h @@ -0,0 +1,76 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_CORE_INTERNAL_H__ +#define __SVR2_CORE_INTERNAL_H__ + +#include + +#include 
"raft/log.h" +#include "raft/raft.h" +#include "db/db.h" +#include "proto/e2e.pb.h" +#include "proto/msgs.pb.h" +#include "proto/raft.pb.h" + +namespace svr2::core::internal { + +typedef uint64_t TransactionID; + +struct WaitingForFirstConnection { + peerid::PeerID peer; + TransactionID join_tx; +}; +struct Loading { + enclaveconfig::RaftGroupConfig group_config; + raft::ReplicaGroup replica_group; + std::unique_ptr log; + std::unique_ptr db; + std::unique_ptr mem; + peerid::PeerID load_from; + TransactionID join_tx; + bool started; + uint64_t replication_id; + uint64_t replication_sequence; + std::string lexigraphically_largest_row_loaded_into_db; +}; +struct Loaded { + enclaveconfig::RaftGroupConfig group_config; + std::unique_ptr raft; + std::unique_ptr db; + raft::LogIdx db_last_applied_log; +}; +struct Raft { + Raft() { ClearState(); } + void ClearState() REQUIRES(mu) { + state = svr2::RAFTSTATE_NO_STATE; + waiting_for_first_connection = { + .peer = peerid::PeerID(), + .join_tx = 0, + }; + loading = { + .group_config = enclaveconfig::RaftGroupConfig(), + .replica_group = raft::ReplicaGroup(), + .log = nullptr, + .db = nullptr, + .join_tx = 0, + .started = false, + .replication_sequence = 0, + .lexigraphically_largest_row_loaded_into_db = "", + }; + loaded = { + .group_config = enclaveconfig::RaftGroupConfig(), + .raft = nullptr, + .db = nullptr, + .db_last_applied_log = 0, + }; + } + mutable util::mutex mu; // protects everything else in this struct. 
+ RaftState state GUARDED_BY(mu); + WaitingForFirstConnection waiting_for_first_connection GUARDED_BY(mu); + Loading loading GUARDED_BY(mu); + Loaded loaded GUARDED_BY(mu); +}; + +} // namespace svr2::core::internal +#endif // __SVR2_CORE_INTERNAL_H__ diff --git a/enclave/core/tests/core.cc b/enclave/core/tests/core.cc new file mode 100644 index 0000000..e061487 --- /dev/null +++ b/enclave/core/tests/core.cc @@ -0,0 +1,2306 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP core/coretest +//TESTDEP core +//TESTDEP groupclock +//TESTDEP timeout +//TESTDEP client +//TESTDEP db +//TESTDEP raft +//TESTDEP peers +//TESTDEP peerid +//TESTDEP sender +//TESTDEP util +//TESTDEP context +//TESTDEP hmac +//TESTDEP noise +//TESTDEP noise-c +//TESTDEP noisewrap +//TESTDEP env +//TESTDEP env/test +//TESTDEP sip +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include +#include +#include +#include + +#include +#include + +#include "core/core.h" +#include "env/env.h" +#include "util/log.h" +#include "proto/enclaveconfig.pb.h" +#include "proto/e2e.pb.h" +#include "proto/client3.pb.h" +#include "noise/noise.h" +#include "env/test/test.h" +#include "util/bytes.h" +#include "db/db3.h" +#include "core/coretest/testingcore.h" +#include "core/coretest/replicagroup.h" +#include "core/coretest/testingclient.h" + +// This test is pretty large and contains a lot of code which should maybe be +// moved into some coretest library at a later date. 
There's a few very +// important functions in the CoreTest fixture: +// +// - PassMessages - pass a series of messages between multiple cores +// - ClientRequest - issue a client request and get back a response +// +// Both PassMessages and ClientRequest rely on a "CoreSet" of a group of cores +// that can pass messages to each other, and a "first" core, a core to which +// a starting message has just been sent and which should have put a first +// set of messages into env::test::SentMessages. +// +// Tests are then built on top of these functions. + +#define NOISE_OK(x) do { \ + int out = (x); \ + if (out != NOISE_ERROR_NONE) { \ + char buf[64]; \ + noise_strerror(out, buf, sizeof(buf)); \ + ASSERT_EQ(out, NOISE_ERROR_NONE) << "Noise error: " << buf; \ + } \ +} while (0) + +namespace svr2::core { +using svr2::core::test::TestingCore; +using svr2::core::test::ReplicaGroup; +using svr2::core::test::TestingClient; + +namespace { +struct ReplicaGroupConfig { + enclaveconfig::EnclaveConfig ecfg; + uint32_t min_voting; + uint32_t max_voting; + size_t initial_voting; + size_t initial_nonvoting; + size_t initial_nonmember; + + enclaveconfig::InitConfig init_config() const { + enclaveconfig::InitConfig cfg; + cfg.mutable_enclave_config()->MergeFrom(ecfg); + cfg.mutable_group_config()->set_db_version(enclaveconfig::DATABASE_VERSION_SVR2); + cfg.mutable_group_config()->set_min_voting_replicas(min_voting); + cfg.mutable_group_config()->set_max_voting_replicas(max_voting); + cfg.mutable_group_config()->set_attestation_timeout(3600); + return cfg; + } +}; + +enum class CoreRole { + Leader, + VotingNonLeader, + NonVoting +}; +}; + +class CoreTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } + + HostToEnclaveResponse Response(std::vector msgs) { + CHECK(msgs.size() == 1); + CHECK(msgs[0].inner_case() == EnclaveMessage::kH2EResponse); + return std::move(*msgs[0].mutable_h2e_response()); + } + + std::vector SentMessages() { + return 
env::test::SentMessages(); + } + + void SetUp() { + ctx = &ctx_; + // clear sent messages. + SentMessages(); + valid_enclave_config.Clear(); + auto raft_config = valid_enclave_config.mutable_raft(); + raft_config->set_election_ticks(4); + raft_config->set_heartbeat_ticks(2); + raft_config->set_replication_chunk_bytes(1<<20); + raft_config->set_replica_voting_timeout_ticks(16); + raft_config->set_replica_membership_timeout_ticks(32); + raft_config->set_log_max_bytes(1<<20); + valid_enclave_config.set_e2e_txn_timeout_ticks(30); + valid_enclave_config.set_send_timestamp_ticks(10); + client_request = 10000; + valid_init_config.Clear(); + valid_init_config.mutable_enclave_config()->CopyFrom(valid_enclave_config); + valid_init_config.set_initial_timestamp_unix_secs(1); + valid_init_config.mutable_group_config()->set_db_version(enclaveconfig::DATABASE_VERSION_SVR2); + valid_init_config.mutable_group_config()->set_min_voting_replicas(1); + valid_init_config.mutable_group_config()->set_max_voting_replicas(5); + valid_init_config.mutable_group_config()->set_attestation_timeout(3600); + valid_init_config.mutable_group_config()->set_simulated(true); + } + + typedef std::map CoreMap; + typedef std::map> PassMessagesOut; + + // Passes back and forth all PeerMessage messages, and returns all non-PeerMessage + // messages, until there are no more messages to pass. The messages in SentMessages + // are considered to be from [first]. 
+ PassMessagesOut PassMessages(const CoreMap& cores, Core* first) { + PassMessagesOut out; + bool quiescent = false; + std::map> to_send; + auto first_msgs = env::test::SentMessages(); + LOG(INFO) << "### starting message passing from " << first->ID() << " with " << first_msgs.size() << " messages"; + std::move(std::begin(first_msgs), std::end(first_msgs), std::back_inserter(to_send[first->ID()])); + while (to_send.size()) { + auto i = to_send.begin(); + const peerid::PeerID& from = i->first; + std::deque* msgs = &i->second; + if (msgs->size() == 0) { + to_send.erase(i); // erase by iterator: `from` aliases the key stored in this very node + continue; + } + EnclaveMessage msg = std::move(msgs->front()); + msgs->pop_front(); + if (msg.inner_case() != EnclaveMessage::kPeerMessage) { + LOG(INFO) << "# non-peer message from " << from; + out[from].push_back(std::move(msg)); + continue; + } + peerid::PeerID to; + to.FromString(msg.peer_message().peer_id()); + auto find = cores.find(to); + if (find == cores.end()) { + LOG(INFO) << "# offline recipient " << to; + out[from].push_back(std::move(msg)); // msg is still intact here: its payload is only moved out below + continue; + } + UntrustedMessage req; + *req.mutable_peer_message() = std::move(*msg.mutable_peer_message()); + from.ToString(req.mutable_peer_message()->mutable_peer_id()); + context::Context ctx; + LOG(INFO) << "#####################################################"; + LOG(INFO) << "# peer message to " << to << " from " << from; + find->second->Receive(&ctx, req); + auto out_msgs = env::test::SentMessages(); + LOG(INFO) << "# yielded " << out_msgs.size(); + std::move(std::begin(out_msgs), std::end(out_msgs), std::back_inserter(to_send[to])); + } + LOG(INFO) << "### message passing complete"; + return out; + } + + uint64_t client_request; + + void ClientRequest(const CoreMap& cores, Core* core, const google::protobuf::MessageLite& req, google::protobuf::MessageLite* cli_resp, const std::string auth_id) { + // Set up client handshake.
+ NoiseHandshakeState* hsp; + NOISE_OK(noise_handshakestate_new_by_id(&hsp, &client::client_protocol, NOISE_ROLE_INITIATOR)); + noise::HandshakeState hs = noise::WrapHandshakeState(hsp); + + uint64_t client_id = 0; + { // Create new client + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(++client_request); + auto newc = host->mutable_new_client(); + newc->set_client_authenticated_id(auth_id); + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto out = PassMessages(cores, core); + ASSERT_EQ(out[core->ID()].size(), 1); + auto resp = out[core->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), client_request); + client_id = resp.new_client_reply().client_id(); + ASSERT_GT(client_id, 0); + + auto hs_msg = resp.new_client_reply().handshake_start(); + NOISE_OK(noise_dhstate_set_public_key( + noise_handshakestate_get_remote_public_key_dh(hsp), + noise::StrU8Ptr(hs_msg.test_only_pubkey()), + hs_msg.test_only_pubkey().size())); + } + NOISE_OK(noise_handshakestate_start(hsp)); + ASSERT_EQ(NOISE_ACTION_WRITE_MESSAGE, noise_handshakestate_get_action(hsp)); + + NoiseCipherState* txp; + NoiseCipherState* rxp; + { // Finish client handshake + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(++client_request); + auto ec = host->mutable_existing_client(); + ec->mutable_data()->resize(noise::HANDSHAKE_INIT_SIZE, '\0'); + NoiseBuffer write_buf = noise::BufferOutputFromString(ec->mutable_data()); + NOISE_OK(noise_handshakestate_write_message(hsp, &write_buf, nullptr)); + ec->mutable_data()->resize(write_buf.size, '\0'); + ec->set_client_id(client_id); + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto out = PassMessages(cores, core); + ASSERT_EQ(out[core->ID()].size(), 1); + auto resp = out[core->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), client_request); + ASSERT_EQ(NOISE_ACTION_READ_MESSAGE, noise_handshakestate_get_action(hsp)); + 
ASSERT_EQ(resp.status(), error::OK); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kExistingClientReply); + auto crep = resp.mutable_existing_client_reply(); + NoiseBuffer read_buf = noise::BufferInputFromString(crep->mutable_data()); + NOISE_OK(noise_handshakestate_read_message(hsp, &read_buf, nullptr)); + ASSERT_EQ(NOISE_ACTION_SPLIT, noise_handshakestate_get_action(hsp)); + NOISE_OK(noise_handshakestate_split(hsp, &txp, &rxp)); + } + noise::CipherState tx = noise::WrapCipherState(txp); + noise::CipherState rx = noise::WrapCipherState(rxp); + { // send the request, parse response. + std::string req_str; + ASSERT_TRUE(req.SerializeToString(&req_str)); + auto [ciphertext, encrypt_err] = noise::Encrypt(txp, req_str); + ASSERT_EQ(error::OK, encrypt_err); + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(++client_request); + auto ec = host->mutable_existing_client(); + ec->set_client_id(client_id); + ec->set_data(ciphertext); + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto out = PassMessages(cores, core); + ASSERT_EQ(out[core->ID()].size(), 1); + auto resp = out[core->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), client_request); + ASSERT_EQ(resp.status(), error::OK); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kExistingClientReply); + auto ec2 = resp.existing_client_reply(); + auto [plaintext, decrypt_err] = noise::Decrypt(rxp, ec2.data()); + ASSERT_EQ(error::OK, decrypt_err); + ASSERT_TRUE(cli_resp->ParseFromString(plaintext)); + } + } + + UntrustedMessage PeerMessage(const peerid::PeerID& from, const peerid::PeerID& to, EnclaveMessage msg) { + CHECK(msg.inner_case() == EnclaveMessage::kPeerMessage); + if (msg.peer_message().peer_id() != to.AsString()) { + peerid::PeerID id; + CHECK(error::OK == id.FromString(msg.peer_message().peer_id())); + LOG(ERROR) << "unexpected peer ID: " << id; + CHECK(false); + } + UntrustedMessage req; + *req.mutable_peer_message() = 
std::move(*msg.mutable_peer_message()); + from.ToString(req.mutable_peer_message()->mutable_peer_id()); + return req; + } + + enclaveconfig::EnclaveConfig valid_enclave_config; + enclaveconfig::InitConfig valid_init_config; + context::Context ctx_; + context::Context* ctx; +}; + +static void BackupRestoreTest(ReplicaGroupConfig cfg, CoreRole connect_to, bool drop_leader, std::map& partition) { + ReplicaGroup replica_group{}; + replica_group.Init(cfg.init_config(), cfg.initial_voting, cfg.initial_nonvoting, cfg.initial_nonmember); + + // tik tok + replica_group.TickTock(false); + replica_group.TickTock(false); + + auto [pin, e1] = util::StringToByteArray<32>("PIN45678901234567890123456789012"); + auto [secret, e2] = util::StringToByteArray<48>("SECRET78901234567890123456789012"); + ASSERT_TRUE(e1 == error::OK && e2 == error::OK); + + size_t core_num = 0; + switch(connect_to) { + case CoreRole::Leader: + core_num = 0; + break; + case CoreRole::VotingNonLeader: + ASSERT_TRUE(cfg.initial_voting > 1); + core_num = 1; + break; + case CoreRole::NonVoting: + ASSERT_TRUE(cfg.initial_nonvoting > 1); + core_num = cfg.initial_voting; + break; + } + + auto client_core = replica_group.get_core(core_num); + + // Block 1: Client requests backup + { + TestingClient cl(*client_core, "authenticated_id"); + + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + cl.RequestBackup(secret, pin, 10); + replica_group.TickTock(false); + + auto backup_response = cl.get_backup_response(); + ASSERT_NE(backup_response, nullptr); + LOG(INFO) << "created backup"; + ASSERT_EQ(backup_response->status(), client::BackupResponse::OK); + } + { + TestingClient cl(*client_core, "authenticated_id"); + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + cl.RequestExpose(secret); + replica_group.TickTock(false); + auto expose_response = cl.get_expose_response(); + ASSERT_NE(expose_response, nullptr); + LOG(INFO) << "backup 
expose"; + ASSERT_EQ(expose_response->status(), client::ExposeResponse::OK); + } + + // Now introduce problems if requested + if(drop_leader) { + auto leader = replica_group.get_core(replica_group.GroupLeaderIndex()); + leader->Pause(false); + } + replica_group.CreatePartition(partition); + + // run long enough to elect a new leader + for(size_t i = 0; i < 4*cfg.ecfg.raft().election_ticks(); ++i) { + replica_group.TickTock(false); + } + + // Block 2: Client requests restore + { + auto [main_partition, partition_size] = test::LargestPartition(partition); + switch(connect_to) { + case CoreRole::Leader: + core_num = replica_group.GroupLeaderIndex(); + break; + case CoreRole::VotingNonLeader: { + // Can't capture main_partition until C++20, need to assign it + auto maybe_it = std::find_if(partition.begin(), partition.end(), + [mp = main_partition, &replica_group, cfg](auto it) { + auto c = replica_group.get_core(it.first); + return it.second == mp + && !c->leader() + && c->voting();}); + // Make sure you put some voting members in the big partition or this + // will fail + ASSERT_NE(maybe_it, partition.end()); + core_num = maybe_it->first; + break; + } + case CoreRole::NonVoting: { + auto maybe_it = std::find_if(partition.begin(), partition.end(), + [mp = main_partition, &replica_group, cfg](auto it) { + auto c = replica_group.get_core(it.first); + return it.second == mp + && !c->voting() + && c->serving();}); + // Make sure you put some non-voting members in the big partition or + // this will fail! 
+ ASSERT_NE(maybe_it, partition.end()); + core_num = maybe_it->first; + break; + } + } + + client_core = replica_group.get_core(core_num); + TestingClient cl(*client_core, "authenticated_id"); + + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + LOG(INFO) << "About to restore with core " << core_num << " (CoreRole: " << (int)connect_to << ")"; + cl.RequestRestore(pin); + replica_group.TickTock(false); + + auto restore_response = cl.get_restore_response(); + ASSERT_NE(restore_response, nullptr); + LOG(INFO) << "Super Secret: " << restore_response->data(); + ASSERT_EQ(util::ByteArrayToString(secret), restore_response->data()); + } +} + +static void WrongPINTest(ReplicaGroupConfig cfg, CoreRole connect_to, bool drop_leader, std::map& partition) { + + ReplicaGroup replica_group{}; + replica_group.Init(cfg.init_config(), cfg.initial_voting, cfg.initial_nonvoting, cfg.initial_nonmember); + + // tik tok + replica_group.TickTock(false); + replica_group.TickTock(false); + + auto [pin, e1] = util::StringToByteArray<32>("PIN45678901234567890123456789012"); + auto [wrong_pin, e2] = util::StringToByteArray<32>("SIN45678901234567890123456789012"); + auto [secret, e3] = util::StringToByteArray<48>("SECRET78901234567890123456789012"); + ASSERT_TRUE(e1 == error::OK && e2 == error::OK && e3 == error::OK); + size_t num_tries = 3; + + size_t core_num = 0; + switch(connect_to) { + case CoreRole::Leader: + core_num = 0; + break; + case CoreRole::VotingNonLeader: + ASSERT_TRUE(cfg.initial_voting > 1); + core_num = 1; + break; + case CoreRole::NonVoting: + ASSERT_TRUE(cfg.initial_nonvoting > 1); + core_num = cfg.initial_voting; + break; + } + + auto client_core = replica_group.get_core(core_num); + + // Block 1: Client requests backup + { + TestingClient cl(*client_core, "authenticated_id"); + + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + cl.RequestBackup(secret, pin, num_tries); + 
replica_group.TickTock(false); + + auto backup_response = cl.get_backup_response(); + ASSERT_NE(backup_response, nullptr); + LOG(INFO) << "created backup"; + } + { + TestingClient cl(*client_core, "authenticated_id"); + + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + cl.RequestExpose(secret); + replica_group.TickTock(false); + + auto expose_response = cl.get_expose_response(); + ASSERT_NE(expose_response, nullptr); + LOG(INFO) << "backup expose"; + } + + // Now introduce problems if requested + if(drop_leader) { + auto leader = replica_group.get_core(replica_group.GroupLeaderIndex()); + leader->Pause(false); + } + replica_group.CreatePartition(partition); + + // run long enough to elect a new leader + for(size_t i = 0; i < 4*cfg.ecfg.raft().election_ticks(); ++i) { + replica_group.TickTock(false); + } + + // Block 2: Client requests restore with wrong pin + { + + auto [main_partition, partition_size] = test::LargestPartition(partition); + switch(connect_to) { + case CoreRole::Leader: + core_num = replica_group.GroupLeaderIndex(); + break; + case CoreRole::VotingNonLeader: { + // Can't capture main_partition until C++20, need to assign it + auto maybe_it = std::find_if(partition.begin(), partition.end(), + [mp = main_partition, &replica_group, cfg](auto it) { + auto c = replica_group.get_core(it.first); + return it.second == mp + && !c->leader() + && c->voting();}); + // Make sure you put some voting members in the big partition or this + // will fail + ASSERT_NE(maybe_it, partition.end()); + core_num = maybe_it->first; + break; + } + case CoreRole::NonVoting: { + auto maybe_it = std::find_if(partition.begin(), partition.end(), + [mp = main_partition, &replica_group, cfg](auto it) { + auto c = replica_group.get_core(it.first); + return it.second == mp + && !c->voting() + && c->serving();}); + // Make sure you put some non-voting members in the big partition or + // this will fail!
+ ASSERT_NE(maybe_it, partition.end()); + core_num = maybe_it->first; + break; + } + } + + client_core = replica_group.get_core(core_num); + TestingClient cl(*client_core, "authenticated_id"); + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + for(size_t i = 0; i < num_tries; ++i) { + cl.RequestRestore(wrong_pin); + replica_group.TickTock(false); + + auto restore_response = cl.get_restore_response(); + ASSERT_NE(restore_response, nullptr); + LOG(INFO) << "tries remaining: " << restore_response->tries() << " data: " << restore_response->data(); + ASSERT_NE(util::ByteArrayToString(secret), restore_response->data()); + } + + // now try correct PIN and confirm it is gone + cl.RequestRestore(pin); + replica_group.TickTock(false); + + auto restore_response = cl.get_restore_response(); + ASSERT_NE(restore_response, nullptr); + LOG(INFO) << "correct PIN tries remaining: " << restore_response->tries() << " data: " << restore_response->data(); + ASSERT_NE(util::ByteArrayToString(secret), restore_response->data()); + } +} + +void ConfirmWillNotServeClientRequests(ReplicaGroup& replica_group) { + auto leader = replica_group.get_leader_core(); + auto [pin, e1] = util::StringToByteArray<32>("PIN45678901234567890123456789012"); + auto [wrong_pin, e2] = util::StringToByteArray<32>("SIN45678901234567890123456789012"); + auto [secret, e3] = util::StringToByteArray<48>("SECRET78901234567890123456789012"); + ASSERT_TRUE(e1 == error::OK && e2 == error::OK && e3 == error::OK); + size_t num_tries = 3; + + //Client requests backup + TestingClient cl(*leader, "authenticated_id"); + + cl.RequestHandshake(); + // start the handshake + ASSERT_EQ(error::OK, leader->ProcessAllIncomingMessages()); + ASSERT_EQ(error::OK, leader->ProcessAllH2EResponses()); + //finish the handshake + ASSERT_EQ(error::OK, leader->ProcessAllIncomingMessages()); + ASSERT_EQ(error::OK, leader->ProcessAllH2EResponses()); + + cl.RequestBackup(secret, pin, num_tries); + 
ASSERT_EQ(error::OK, leader->ProcessAllIncomingMessages()); + + auto h2e_msgs = leader->take_host_to_enclave_responses(); + auto& h2e_response = h2e_msgs[0]; + ASSERT_EQ(h2e_response.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(h2e_response.status(), error::Core_NotEnoughVotingReplicas); +} + +void SelfHealingTest(ReplicaGroupConfig cfg) { + ReplicaGroup replica_group{}; + replica_group.Init(cfg.init_config(), cfg.initial_voting, cfg.initial_nonvoting, cfg.initial_nonmember); + + size_t initial_members = cfg.initial_nonvoting + cfg.initial_voting; + ASSERT_EQ(replica_group.num_voting(), 1); + + for(size_t i = 0; i < initial_members; ++i) { + if (replica_group.num_voting() < cfg.min_voting) { + ConfirmWillNotServeClientRequests(replica_group); + } + + replica_group.TickTock(false); + ASSERT_EQ(replica_group.num_voting(), std::min(2+i, initial_members)); + } + size_t num_voting = replica_group.num_voting(); + + // remove two voting non-leader + LOG(INFO) << "Removing two non-leader voting members"; + auto leader_core = replica_group.get_leader_core(); + TestingCore* non_leader_core = replica_group.get_voting_nonleader_core(); + if(non_leader_core != nullptr) { + LOG(INFO) << "STOPPING peer " << non_leader_core->ID() + << "(leader: " << leader_core->ID() << ")"; + non_leader_core->Pause(false); + } + non_leader_core = replica_group.get_voting_nonleader_core(); + if(non_leader_core != nullptr) { + LOG(INFO) << "STOPPING peer " << non_leader_core->ID() + << "(leader: " << leader_core->ID() << ")"; + non_leader_core->Pause(false); + } + + // even though we stopped them the replica group counts them as voting + ASSERT_EQ(replica_group.num_voting(), num_voting); + + // replica_membership_timeout_ticks is time to kick out a member + // replica_voting_timeout_ticks is time to demote from voting + // tick until it is demoted + auto voting_timeout_ticks = cfg.ecfg.raft().replica_voting_timeout_ticks(); + for(size_t i = 0; i < voting_timeout_ticks; ++i) { + 
LOG(INFO) << "\nTICK " << i << "\n"; + replica_group.TickTock(false); + } + LOG(INFO) << "NUM_VOTING before demotion: " << num_voting << " after demotion: " + << replica_group.num_voting(); + ASSERT_EQ(replica_group.num_voting(), num_voting - 1); + + // Tick again and eliminate the second core + replica_group.TickTock(false); + LOG(INFO) << "NUM_VOTING before demotion: " << num_voting << " after demotion: " + << replica_group.num_voting(); + ASSERT_EQ(replica_group.num_voting(), num_voting - 2); +} + +TEST_F(CoreTest, SelfHealingGrowthTest) { + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 5, + .max_voting = 9, + .initial_voting = 1, + .initial_nonvoting = 7, + .initial_nonmember = 2 + }; + SelfHealingTest(cfg); +} + +TEST_F(CoreTest, CreateReplicaGroup) { + ReplicaGroup replica_group{}; + ReplicaGroupConfig cfg{ + .ecfg = valid_enclave_config, + .min_voting = 1, + .max_voting = 1, + }; + replica_group.Init(cfg.init_config(), 5, 3, 2); + + // tik tok + replica_group.TickTock(false); + replica_group.TickTock(false); + + ASSERT_TRUE(replica_group.get_core(0)->leader()); + ASSERT_TRUE(replica_group.get_core(0)->serving()); + for(size_t i = 1; i < 8; ++i) { + ASSERT_FALSE(replica_group.get_core(i)->leader()); + ASSERT_TRUE(replica_group.get_core(i)->serving()); + } + + ASSERT_FALSE(replica_group.get_core(8)->leader()); + ASSERT_FALSE(replica_group.get_core(8)->serving()); + ASSERT_FALSE(replica_group.get_core(9)->leader()); + ASSERT_FALSE(replica_group.get_core(9)->serving()); + + replica_group.TickTock(false); + LOG(INFO) << "\nREMOVING LEADER\n" << " current leader: " << replica_group.GroupLeaderIndex() << "\n"; + + // Now take out the leader + replica_group.get_core(0)->Pause(false); + // tik tok until election should have happened + for(size_t i = 0; i < 4*valid_enclave_config.raft().election_ticks(); ++i) { + replica_group.TickTock(false); + } + ASSERT_TRUE( + replica_group.get_core(1)->leader() || + replica_group.get_core(2)->leader() 
|| + replica_group.get_core(3)->leader() || + replica_group.get_core(4)->leader()); + + LOG(INFO) << "\nNEW LEADER\n" << " current leader: " << replica_group.GroupLeaderIndex() + << " (" << replica_group.GroupLeader() << ")\n"; +} + +TEST_F(CoreTest, TestPartition) { + ReplicaGroup replica_group{}; + ReplicaGroupConfig cfg{ + .ecfg = valid_enclave_config, + .min_voting = 1, + .max_voting = 1, + }; + replica_group.Init(cfg.init_config(), 5, 3, 2); + + // tik tok + replica_group.TickTock(false); + replica_group.TickTock(false); + + LOG(INFO) << "\nCREATE PARTITION\n" << " current leader: " << replica_group.GroupLeaderIndex() << "\n"; + replica_group.CreatePartition(std::map{ + {0,1}, {1,1}, {5,1}, {6,1}, // has leader but only one other voting member + {2,2}, {3,2}, {4,2}, {7,2}, {8,2}, {9,2} // no leader, but quorum of voting members. Should take over + }); + + // tik tok until election should have happened and completed (it might go through more than one cycle) + for(size_t i = 0; i < 4*valid_enclave_config.raft().election_ticks(); ++i) { + replica_group.TickTock(false); + } + + // add a voting member during the outage + LOG(INFO) << "Core 8 Joining"; + auto peer3_id = replica_group.get_core(3)->ID(); + ASSERT_EQ(error::OK, replica_group.get_core(8)->JoinRaft(peer3_id)); + replica_group.TickTock(false); + replica_group.TickTock(false); + + LOG(INFO) << "Request voting for core " << replica_group.get_core(8)->ID(); + ASSERT_EQ(error::OK, replica_group.get_core(8)->RequestVoting()); + // ignore errors because raft joining might have failed if, e.g., the + // load request was sent to a disconnected peer + replica_group.TickTock(true); + + LOG(INFO) << "\nCLEAR PARTITION\n"; + replica_group.ClearPartition(); + replica_group.ForwardBlockedMessages(); + replica_group.TickTock(true); + // replica_group.ClearBlockedMessages(); // This will drop all messages and leave replicas stuck in-flight until self-healing + replica_group.PassMessagesUntilQuiet(); + + // for(size_t i =
0; i < 2*valid_enclave_config.raft().election_ticks(); ++i) { + // replica_group.TickTock(false); + // } + + LOG(INFO) << "\nNEW LEADER\n" << " current leader: " << replica_group.GroupLeaderIndex() + << " (" << replica_group.GroupLeader() << ")\n"; + for(size_t i = 0; i < 10; ++i) { + LOG(INFO) << "replica " << i << " (" << replica_group.get_core(i)->ID() + << ") is_leader: " << replica_group.get_core(i)->leader() + << ") serving: " << replica_group.get_core(i)->serving(); + } + + ASSERT_TRUE( + replica_group.get_core(2)->leader() || + replica_group.get_core(3)->leader() || + replica_group.get_core(4)->leader()); +} + +TEST_F(CoreTest, BackupRestorePartitionNetworkTest) { + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 1, + .max_voting = 1, + .initial_voting = 5, + .initial_nonvoting = 3, + .initial_nonmember = 2 + }; + std::map partition = { + {0,1}, {1,1}, {5,1}, {6,1}, // has leader but only one other voting member + {2,2}, {3,2}, {4,2}, {7,2}, {8,2}, {9,2} // no leader, but quorum of voting members. Should take over + }; + BackupRestoreTest(cfg, CoreRole::Leader, false, partition); + BackupRestoreTest(cfg, CoreRole::VotingNonLeader, false, partition); + BackupRestoreTest(cfg, CoreRole::NonVoting, false, partition); +} + +TEST_F(CoreTest, WrongPINPartitionNetworkTest) { + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 1, + .max_voting = 1, + .initial_voting = 5, + .initial_nonvoting = 3, + .initial_nonmember = 2 + }; + std::map partition = { + {0,1}, {1,1}, {5,1}, {6,1}, // has leader but only one other voting member + {2,2}, {3,2}, {4,2}, {7,2}, {8,2}, {9,2} // no leader, but quorum of voting members. 
Should take over + }; + + // TODO: Consider parameterized tests (http://google.github.io/googletest/reference/testing.html#TEST_P) + WrongPINTest(cfg, CoreRole::Leader, false, partition); + WrongPINTest(cfg, CoreRole::VotingNonLeader, false, partition); + WrongPINTest(cfg, CoreRole::NonVoting, false, partition); +} + +TEST_F(CoreTest, BackupRestoreHealthyNetworkTest) { + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 1, + .max_voting = 1, + .initial_voting = 5, + .initial_nonvoting = 3, + .initial_nonmember = 2 + }; + // no partition in network + std::map partition = { + {0,1}, {1,1}, {5,1}, {6,1}, {9,1}, + {2,1}, {3,1}, {4,1}, {7,1}, {8,1} + }; + BackupRestoreTest(cfg, CoreRole::Leader, false, partition); + BackupRestoreTest(cfg, CoreRole::VotingNonLeader, false, partition); + BackupRestoreTest(cfg, CoreRole::NonVoting, false, partition); +} + +TEST_F(CoreTest, WrongPINHealthyNetworkTest) { + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 1, + .max_voting = 1, + .initial_voting = 5, + .initial_nonvoting = 3, + .initial_nonmember = 2 + }; + // no partition in network + std::map partition = { + {0,1}, {1,1}, {5,1}, {6,1}, {9,1}, + {2,1}, {3,1}, {4,1}, {7,1}, {8,1} + }; + WrongPINTest(cfg, CoreRole::Leader, false, partition); + WrongPINTest(cfg, CoreRole::VotingNonLeader, false, partition); + WrongPINTest(cfg, CoreRole::NonVoting, false, partition); +} + +TEST_F(CoreTest, BackupRestoreDropLeaderTest) { + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 1, + .max_voting = 1, + .initial_voting = 5, + .initial_nonvoting = 3, + .initial_nonmember = 2 + }; + // no partition in network + std::map partition = { + {0,1}, {1,1}, {5,1}, {6,1}, {9,1}, + {2,1}, {3,1}, {4,1}, {7,1}, {8,1} + }; + BackupRestoreTest(cfg, CoreRole::Leader, true, partition); + BackupRestoreTest(cfg, CoreRole::VotingNonLeader, true, partition); + BackupRestoreTest(cfg, CoreRole::NonVoting, true, partition); +} + 
+// Exercises WrongPINTest once for each CoreRole on an unpartitioned network,
+// passing `true` as the helper's third argument.
+TEST_F(CoreTest, WrongPINDropLeaderTest) {
+  // Ten replicas in total: 5 voting, 3 non-voting and 2 non-members.
+  ReplicaGroupConfig group_cfg = {
+    .ecfg = valid_enclave_config,
+    .min_voting = 1,
+    .max_voting = 1,
+    .initial_voting = 5,
+    .initial_nonvoting = 3,
+    .initial_nonmember = 2
+  };
+  // Every replica index (0..9) is placed in partition 1, i.e. there is
+  // no partition in the network.
+  std::map single_partition = {
+    {0,1}, {1,1}, {2,1}, {3,1}, {4,1},
+    {5,1}, {6,1}, {7,1}, {8,1}, {9,1}
+  };
+  // NOTE(review): the third argument is `true` only in the *DropLeader tests,
+  // so it presumably asks the helper to drop the leader -- confirm against
+  // WrongPINTest's signature.
+  for (auto role : {CoreRole::Leader, CoreRole::VotingNonLeader, CoreRole::NonVoting}) {
+    WrongPINTest(group_cfg, role, true, single_partition);
+  }
+}
+
+// Requests enclave status from a leader and a follower, pauses the leader,
+// ticks the group through an election window, and then requests status again
+// from the newly elected leader and a follower.
+TEST_F(CoreTest, EnclaveStatus) {
+  ReplicaGroup group{};
+  ReplicaGroupConfig group_cfg{
+    .ecfg = valid_enclave_config,
+    .min_voting = 1,
+    .max_voting = 1,
+  };
+  // Arguments after the config: initial voting, non-voting, non-member counts.
+  group.Init(group_cfg.init_config(), 5, 2, 0);
+
+  // First round: core 0 is used as the leader, core 1 as the follower.
+  auto leader_core = group.get_core(0);
+  auto follower_core = group.get_core(1);
+  ASSERT_EQ(error::OK, leader_core->ProcessAllH2EResponses());
+  ASSERT_EQ(error::OK, follower_core->ProcessAllH2EResponses());
+  leader_core->GetEnclaveStatus();
+  follower_core->GetEnclaveStatus();
+
+  group.PassMessagesUntilQuiet();
+  // The replies are only taken, never inspected.
+  // NOTE(review): presumably TakeExpectedEnclaveStatusReply() itself fails
+  // the test when no reply is pending -- confirm.
+  auto leader_reply = leader_core->TakeExpectedEnclaveStatusReply();
+  auto follower_reply = follower_core->TakeExpectedEnclaveStatusReply();
+
+  // Take the current leader out so that a new election has to happen.
+  leader_core->Pause(false);
+
+  // Tick for four election timeouts, long enough to elect a new leader.
+  size_t tick = 0;
+  while (tick < 4*valid_enclave_config.raft().election_ticks()) {
+    group.TickTock(false);
+    ++tick;
+  }
+
+  // Second round: use whichever core the group now reports as leader. If the
+  // old follower won the election, switch to core 2 as the follower.
+  leader_core = group.get_core(group.GroupLeaderIndex());
+  if (follower_core->leader()) {
+    follower_core = group.get_core(2);
+  }
+
+  ASSERT_EQ(error::OK, leader_core->ProcessAllH2EResponses());
+  ASSERT_EQ(error::OK, follower_core->ProcessAllH2EResponses());
+  leader_core->GetEnclaveStatus();
+  follower_core->GetEnclaveStatus();
+
+  group.PassMessagesUntilQuiet();
+  leader_reply = leader_core->TakeExpectedEnclaveStatusReply();
+  follower_reply = follower_core->TakeExpectedEnclaveStatusReply();
+}
+
+TEST_F(CoreTest, ClientRequests) { + auto [core, err] = Core::Create(ctx, valid_init_config); + ASSERT_TRUE(core->ID().Valid()); + CoreMap cores; + cores[core->ID()] = core.get(); + + { // Set up as one-replica Raft + UntrustedMessage msg; + + auto host = msg.mutable_h2e_request(); + host->set_request_id(999); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + ASSERT_EQ(resp.request_id(), 999); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + LOG(INFO) << "sending backup request"; + + client::Request req; + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + ClientRequest(cores, core.get(), req, &resp, "backup7890123456"); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, resp.backup().status()); + + LOG(INFO) << "sending expose request"; + + client::Request req2; + auto a = req2.mutable_expose(); + a->set_data("12345678901234567890123456789012"); + client::Response resp2; + ClientRequest(cores, core.get(), req2, &resp2, "backup7890123456"); + ASSERT_EQ(client::Response::kExpose, resp2.inner_case()); + ASSERT_EQ(client::ExposeResponse::OK, resp2.expose().status()); + + LOG(INFO) << "sending restore request"; + + client::Request req3; + auto r = req3.mutable_restore(); + r->set_pin("12345678901234567890123456789012"); + client::Response resp3; + ClientRequest(cores, core.get(), req3, &resp3, "backup7890123456"); + ASSERT_EQ(client::Response::kRestore, resp3.inner_case()); + ASSERT_EQ(client::RestoreResponse::OK, resp3.restore().status()); +} + +TEST_F(CoreTest, RestoreWithoutExpose) { + auto [core, err] = Core::Create(ctx, valid_init_config); + ASSERT_TRUE(core->ID().Valid()); + CoreMap cores; + 
cores[core->ID()] = core.get(); + + { // Set up as one-replica Raft + UntrustedMessage msg; + + auto host = msg.mutable_h2e_request(); + host->set_request_id(999); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + ASSERT_EQ(resp.request_id(), 999); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + LOG(INFO) << "sending backup request"; + + client::Request req; + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + ClientRequest(cores, core.get(), req, &resp, "backup7890123456"); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, resp.backup().status()); + + LOG(INFO) << "sending restore request"; + + client::Request req3; + auto r = req3.mutable_restore(); + r->set_pin("12345678901234567890123456789012"); + client::Response resp3; + ClientRequest(cores, core.get(), req3, &resp3, "backup7890123456"); + ASSERT_EQ(client::Response::kRestore, resp3.inner_case()); + ASSERT_EQ(client::RestoreResponse::MISSING, resp3.restore().status()); +} + +TEST_F(CoreTest, MultiNodeRaft) { + auto [core1, err1] = Core::Create(ctx, valid_init_config); + ASSERT_EQ(err1, error::OK); + auto [core2, err2] = Core::Create(ctx, valid_init_config); + ASSERT_EQ(err2, error::OK); + auto [core3, err3] = Core::Create(ctx, valid_init_config); + ASSERT_EQ(err3, error::OK); + LOG(INFO) << "core1=" << core1->ID() << ", core2=" << core2->ID() << ", core3=" << core3->ID(); + + // Create cores map for PassMessages + CoreMap cores; + cores[core1->ID()] = core1.get(); + cores[core2->ID()] = core2.get(); + cores[core3->ID()] = core3.get(); + + { + LOG(INFO) << "\n\nSet up as one-replica Raft on core 1"; + UntrustedMessage msg; + auto host = 
msg.mutable_h2e_request(); + host->set_request_id(1000); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core1->Receive(&ctx, msg)); + auto out = env::test::SentMessages(); + ASSERT_EQ(1, out.size()); + auto resp = out[0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1000); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest join on core 2"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1001); + auto req = host->mutable_join_raft(); + core1->ID().ToString(req->mutable_peer_id()); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1001); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest core2 vote"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1002); + host->set_request_voting(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1002); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest join on core 3"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1003); + auto req = host->mutable_join_raft(); + core1->ID().ToString(req->mutable_peer_id()); + + context::Context ctx; + ASSERT_EQ(error::OK, core3->Receive(&ctx, msg)); + auto out = PassMessages(cores, core3.get()); + ASSERT_EQ(1, out[core3->ID()].size()); + auto resp = 
out[core3->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1003); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest core3 vote"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1004); + host->set_request_voting(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core3->Receive(&ctx, msg)); + auto out = PassMessages(cores, core3.get()); + ASSERT_EQ(1, out[core3->ID()].size()); + auto resp = out[core3->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1004); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + EXPECT_TRUE(core1->serving()); + EXPECT_TRUE(core1->leader()); + EXPECT_TRUE(core2->serving()); + EXPECT_FALSE(core2->leader()); + EXPECT_TRUE(core3->serving()); + EXPECT_FALSE(core3->leader()); + + LOG(INFO) << "\n\nRequest to leader core1"; + client::Request req; + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + ClientRequest(cores, core1.get(), req, &resp, "backup7890123456"); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, resp.backup().status()); + + LOG(INFO) << "\n\nElecting next leader"; + const int max_attempts = 100; + cores.erase(core1->ID()); // core1 goes offline + for (int i = 0; i < max_attempts && !core2->leader(); i++) { + LOG(INFO) << "core2 tick"; + UntrustedMessage msg; + msg.mutable_timer_tick()->set_new_timestamp_unix_secs(i); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + PassMessages(cores, core2.get()); + } + EXPECT_TRUE(core2->serving()); + EXPECT_TRUE(core2->leader()); + EXPECT_TRUE(core3->serving()); + EXPECT_FALSE(core3->leader()); + + LOG(INFO) << "\n\nRequest to leader core2"; + + client::Request req2; + auto r = 
req2.mutable_expose(); + r->set_data("12345678901234567890123456789012"); + client::Response resp2; + ClientRequest(cores, core2.get(), req2, &resp2, "backup7890123456"); + ASSERT_EQ(client::Response::kExpose, resp2.inner_case()); + ASSERT_EQ(client::ExposeResponse::OK, resp2.expose().status()); + + LOG(INFO) << "\n\nRequest to non-leader core3"; + + client::Request req3; + auto r3 = req3.mutable_restore(); + r3->set_pin("12345678901234567890123456789012"); + client::Response resp3; + ClientRequest(cores, core3.get(), req3, &resp3, "backup7890123456"); + ASSERT_EQ(client::Response::kRestore, resp3.inner_case()); + ASSERT_EQ(client::RestoreResponse::OK, resp3.restore().status()); +} + +TEST_F(CoreTest, RejectsUnsetHostTransactionID) { + auto [core, err] = Core::Create(ctx, valid_init_config); + ASSERT_EQ(err, error::OK); + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + // host->set_request_id(1004); // Not set, should error out. + host->set_get_enclave_status(true); + context::Context ctx; + err = core->Receive(&ctx, msg); + ASSERT_EQ(err, error::Core_HostToEnclaveTransactionID); +} + +TEST_F(CoreTest, MultiJoinCausesDisconnectedPeersWhichThenConnect) { + ReplicaGroup replica_group; + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 2, + .max_voting = 3, + }; + replica_group.Init( + cfg.init_config(), + 1, // initial_voting + 0, // initial_nonvoting + 2); // initial_nonmember + // By issuing two relatively simultaneous JoinRaft connections, we + // achieve a state where, during the raft Joining protocol, both cores + // 1 and 2 create peer connections to core 0, but they do not establish + // a peer connection to each other. 
+ LOG(INFO) << "Sending joins"; + replica_group.get_core(1)->JoinRaft(replica_group.get_core(0)->ID()); + replica_group.get_core(2)->JoinRaft(replica_group.get_core(0)->ID()); + LOG(INFO) << "Processing messages"; + ASSERT_EQ(error::OK, replica_group.PassMessagesUntilQuiet()); + LOG(INFO) << "Requesting voting"; + replica_group.get_core(1)->RequestVoting(); + ASSERT_EQ(error::OK, replica_group.PassMessagesUntilQuiet()); + replica_group.get_core(2)->RequestVoting(); + ASSERT_EQ(error::OK, replica_group.PassMessagesUntilQuiet()); + LOG(INFO) << "Partitioning"; + replica_group.CreatePartition(std::map{ + {0,1}, + {1,2}, {2,2}, + }); + // What should happen now is that, as part of one of these ticks, + // nodes 1 and 2 should detect that they're not connected to each + // other and establish a connection. In doing so, they make it possible + // for themselves to run a leader election, and one of them should + // be elected leader. + for (int i = 0; i < valid_enclave_config.raft().election_ticks() * 4; i++) { + LOG(INFO) << "Tick " << i; + replica_group.TickTock(2, false); + } + EXPECT_TRUE(replica_group.get_core(1)->leader() || replica_group.get_core(2)->leader()); +} + +TEST_F(CoreTest, SetLogLevel) { + auto old_log_level = ::svr2::util::log_level_to_write; + auto [core, err] = Core::Create(ctx, valid_init_config); + ASSERT_EQ(err, error::OK); + + { + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(100); + host->set_set_log_level(::svr2::enclaveconfig::LOG_LEVEL_MAX); + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + ASSERT_EQ(resp.request_id(), 100); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::Core_InvalidLogLevel); + } + { + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(101); + host->set_set_log_level(::svr2::enclaveconfig::LOG_LEVEL_WARNING); + 
context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + ASSERT_EQ(resp.request_id(), 101); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + ASSERT_EQ(::svr2::util::log_level_to_write, ::svr2::enclaveconfig::LOG_LEVEL_WARNING); + util::SetLogLevel(old_log_level); + ASSERT_EQ(::svr2::util::log_level_to_write, old_log_level); + } +} + +TEST_F(CoreTest, ResetPeer){ + ReplicaGroup replica_group{}; + size_t initial_voting = 4; + size_t initial_nonvoting = 0; + size_t initial_nonmember = 0; + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 2, + .max_voting = 3, + }; + replica_group.Init( + cfg.init_config(), + initial_voting, + initial_nonvoting, + initial_nonmember); + + // get status from leader and follower + auto leader = replica_group.get_core(0); + auto follower = replica_group.get_core(1); + ASSERT_EQ(error::OK, leader->ResetPeer(follower->ID())); + replica_group.PassMessagesUntilQuiet(); + LOG(INFO) << "Reset peer"; + + ASSERT_EQ(error::OK, leader->ProcessAllH2EResponses()); + ASSERT_EQ(error::OK, follower->ProcessAllH2EResponses()); + leader->GetEnclaveStatus(); + follower->GetEnclaveStatus(); + + replica_group.PassMessagesUntilQuiet(); + auto leader_status = leader->TakeExpectedEnclaveStatusReply(); + auto follower_status = follower->TakeExpectedEnclaveStatusReply(); + + for(size_t i = 0; i < leader_status.peers_size(); ++i) { + auto peer_status = leader_status.peers(i); + peerid::PeerID pid; + ASSERT_EQ(error::OK, pid.FromString(peer_status.peer_id())); + if(pid == follower->ID()) { + ASSERT_EQ(PEER_DISCONNECTED, peer_status.connection_status().state()); + } + } + + replica_group.TickTock(false); + replica_group.TickTock(false); + + ASSERT_EQ(error::OK, leader->ProcessAllH2EResponses()); + ASSERT_EQ(error::OK, follower->ProcessAllH2EResponses()); + leader->GetEnclaveStatus(); + follower->GetEnclaveStatus(); + + 
replica_group.PassMessagesUntilQuiet(); + leader_status = leader->TakeExpectedEnclaveStatusReply(); + follower_status = follower->TakeExpectedEnclaveStatusReply(); + + for(size_t i = 0; i < leader_status.peers_size(); ++i) { + auto peer_status = leader_status.peers(i); + peerid::PeerID pid; + ASSERT_EQ(error::OK, pid.FromString(peer_status.peer_id())); + if(pid == follower->ID()) { + ASSERT_EQ(PEER_CONNECTED, peer_status.connection_status().state()); + } + } +} + +TEST_F(CoreTest, ReplicatingRowsWithMultiplePackets) { + enclaveconfig::InitConfig config = valid_init_config; + config.mutable_enclave_config()->mutable_raft()->set_replication_chunk_bytes(10 * 1024); // holds ~17 logs + auto [core1, err1] = Core::Create(ctx, config); + ASSERT_EQ(err1, error::OK); + auto [core2, err2] = Core::Create(ctx, config); + ASSERT_EQ(err2, error::OK); + LOG(INFO) << "core1=" << core1->ID() << ", core2=" << core2->ID(); + + // Create cores map for PassMessages + CoreMap cores; + cores[core1->ID()] = core1.get(); + cores[core2->ID()] = core2.get(); + + { + LOG(INFO) << "\n\nSet up as one-replica Raft on core 1"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1000); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core1->Receive(&ctx, msg)); + auto out = env::test::SentMessages(); + ASSERT_EQ(1, out.size()); + auto resp = out[0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1000); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + for (uint64_t i = 0; i < 100; i++) { // more logs than fit in replication_chunk_bytes + LOG(INFO) << "\n\nRequest to leader core1"; + client::Request req; + std::array backup_id = {0}; + util::BigEndian64Bytes(i, backup_id.data()); + + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + 
ClientRequest(cores, core1.get(), req, &resp, util::ByteArrayToString(backup_id)); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, resp.backup().status()); + } + + { + LOG(INFO) << "\n\nRequest join on core 2"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1001); + auto req = host->mutable_join_raft(); + core1->ID().ToString(req->mutable_peer_id()); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1001); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest core2 vote"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1002); + host->set_request_voting(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1002); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + EXPECT_TRUE(core1->serving()); + EXPECT_TRUE(core1->leader()); + EXPECT_TRUE(core2->serving()); + EXPECT_FALSE(core2->leader()); +} + +TEST_F(CoreTest, ReplicatingRowsWithTruncatedLog) { + enclaveconfig::InitConfig config = valid_init_config; + config.mutable_enclave_config()->mutable_raft()->set_log_max_bytes(10240); // truncate log quickly + auto [core1, err1] = Core::Create(ctx, config); + ASSERT_EQ(err1, error::OK); + auto [core2, err2] = Core::Create(ctx, config); + ASSERT_EQ(err2, error::OK); + LOG(INFO) << "core1=" << core1->ID() << ", core2=" << core2->ID(); + + // Create cores map for PassMessages + CoreMap cores; + cores[core1->ID()] = 
core1.get(); + cores[core2->ID()] = core2.get(); + + { + LOG(INFO) << "\n\nSet up as one-replica Raft on core 1"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1000); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core1->Receive(&ctx, msg)); + auto out = env::test::SentMessages(); + ASSERT_EQ(1, out.size()); + auto resp = out[0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1000); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + for (uint64_t i = 0; i < 100; i++) { // more logs than fit in replication_chunk_bytes + LOG(INFO) << "\n\nRequest to leader core1"; + client::Request req; + std::array backup_id = {0}; + util::BigEndian64Bytes(i, backup_id.data()); + + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + ClientRequest(cores, core1.get(), req, &resp, util::ByteArrayToString(backup_id)); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, resp.backup().status()); + } + + { + LOG(INFO) << "\n\nRequest join on core 2"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1001); + auto req = host->mutable_join_raft(); + core1->ID().ToString(req->mutable_peer_id()); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1001); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest core2 vote"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1002); + host->set_request_voting(true); + + 
context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1002); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + EXPECT_TRUE(core1->serving()); + EXPECT_TRUE(core1->leader()); + EXPECT_TRUE(core2->serving()); + EXPECT_FALSE(core2->leader()); +} + +TEST_F(CoreTest, RaftRemoval) { + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 2, + .max_voting = 3, + .initial_voting = 3, + .initial_nonvoting = 0, + .initial_nonmember = 0, + }; + ReplicaGroup replica_group{}; + replica_group.Init(cfg.init_config(), cfg.initial_voting, cfg.initial_nonvoting, cfg.initial_nonmember); + EXPECT_TRUE(replica_group.get_core(0)->leader()); + EXPECT_TRUE(replica_group.get_core(1)->active()); + EXPECT_TRUE(replica_group.get_core(1)->voting()); + + LOG(INFO) << "================================== REMOVING " << replica_group.get_core(1)->ID(); + replica_group.get_core(1)->RaftRemoval(); + replica_group.PassMessagesUntilQuiet(); + EXPECT_TRUE(replica_group.get_core(0)->leader()); + EXPECT_EQ(0, replica_group.get_core(0)->all_replicas().count(replica_group.get_core(1)->ID())); + EXPECT_EQ(0, replica_group.get_core(2)->all_replicas().count(replica_group.get_core(1)->ID())); + // Keeping these tests in here for illustrative purposes: + // Core 1 has been removed from Raft by this point, but it doesn't KNOW that it + // has, because part of being removed is that it no longer receives Raft logs. 
+ EXPECT_TRUE(replica_group.get_core(1)->active()); + EXPECT_TRUE(replica_group.get_core(1)->voting()); +} + +TEST_F(CoreTest, RaftRemovalOfLeaderFails) { + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 2, + .max_voting = 3, + .initial_voting = 3, + .initial_nonvoting = 0, + .initial_nonmember = 0, + }; + ReplicaGroup replica_group{}; + replica_group.Init(cfg.init_config(), cfg.initial_voting, cfg.initial_nonvoting, cfg.initial_nonmember); + EXPECT_TRUE(replica_group.get_core(0)->leader()); + EXPECT_TRUE(replica_group.get_core(1)->active()); + EXPECT_TRUE(replica_group.get_core(1)->voting()); + replica_group.get_core(0)->RaftRemoval(); + replica_group.PassMessagesUntilQuiet(); + EXPECT_TRUE(replica_group.get_core(0)->leader()); + EXPECT_EQ(1, replica_group.get_core(0)->all_replicas().count(replica_group.get_core(1)->ID())); + EXPECT_EQ(error::Core_LeaderRemovingSelf, replica_group.ProcessAllH2EResponses()); +} + +TEST_F(CoreTest, Hashes2) { + auto [core, err] = Core::Create(ctx, valid_init_config); + ASSERT_TRUE(core->ID().Valid()); + CoreMap cores; + cores[core->ID()] = core.get(); + + { // Set up as one-replica Raft + UntrustedMessage msg; + + auto host = msg.mutable_h2e_request(); + host->set_request_id(999); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + ASSERT_EQ(resp.request_id(), 999); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + LOG(INFO) << "sending backup request"; + + client::Request req; + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + ClientRequest(cores, core.get(), req, &resp, "backup7890123456"); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, 
resp.backup().status()); + + LOG(INFO) << "sending expose request"; + + client::Request req2; + auto a = req2.mutable_expose(); + a->set_data("12345678901234567890123456789012"); + client::Response resp2; + ClientRequest(cores, core.get(), req2, &resp2, "backup7890123456"); + ASSERT_EQ(client::Response::kExpose, resp2.inner_case()); + ASSERT_EQ(client::ExposeResponse::OK, resp2.expose().status()); + + LOG(INFO) << "sending restore request"; + + client::Request req3; + auto r = req3.mutable_restore(); + r->set_pin("12345678901234567890123456789012"); + client::Response resp3; + ClientRequest(cores, core.get(), req3, &resp3, "backup7890123456"); + ASSERT_EQ(client::Response::kRestore, resp3.inner_case()); + ASSERT_EQ(client::RestoreResponse::OK, resp3.restore().status()); + + { + UntrustedMessage msg; + + auto host = msg.mutable_h2e_request(); + host->set_request_id(10101); + host->set_hashes(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + ASSERT_EQ(resp.request_id(), 10101); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kHashes); + ASSERT_EQ(resp.status(), error::OK); + EXPECT_EQ(util::BigEndian64FromBytes(reinterpret_cast(resp.hashes().db_hash().data())), + 5883775926529965153ULL); + EXPECT_EQ(resp.hashes().commit_idx(), 4); + EXPECT_EQ(util::BigEndian64FromBytes(reinterpret_cast(resp.hashes().commit_hash_chain().data())), + 8788200018387288622ULL); + } +} + +TEST_F(CoreTest, ReplicationRandom) { + for (int test_i = 0; test_i < 10; test_i++) { + enclaveconfig::InitConfig config = valid_init_config; + config.mutable_enclave_config()->mutable_raft()->set_log_max_bytes(10240); + auto [core1, err1] = Core::Create(ctx, config); + ASSERT_EQ(err1, error::OK); + auto [core2, err2] = Core::Create(ctx, config); + ASSERT_EQ(err2, error::OK); + LOG(INFO) << "core1=" << core1->ID() << ", core2=" << core2->ID(); + + // Create cores map for PassMessages + CoreMap cores; + 
cores[core1->ID()] = core1.get(); + cores[core2->ID()] = core2.get(); + + { + LOG(INFO) << "\n\nSet up as one-replica Raft on core 1"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1000); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core1->Receive(&ctx, msg)); + auto out = env::test::SentMessages(); + ASSERT_EQ(1, out.size()); + auto resp = out[0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1000); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + for (uint64_t i = 0; i < 100; i++) { // more logs than fit in replication_chunk_bytes + LOG(INFO) << "\n\nRequest to leader core1"; + client::Request req; + std::array backup_id = {0}; + // Randomly order inserts + util::BigEndian64Bytes(rand(), backup_id.data()); + + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + ClientRequest(cores, core1.get(), req, &resp, util::ByteArrayToString(backup_id)); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, resp.backup().status()); + } + + { + LOG(INFO) << "\n\nRequest join on core 2"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1001); + auto req = host->mutable_join_raft(); + core1->ID().ToString(req->mutable_peer_id()); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1001); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest hashes"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1099); + 
host->set_hashes(true); + context::Context ctx; + + ASSERT_EQ(error::OK, core1->Receive(&ctx, msg)); + auto resp1 = Response(env::test::SentMessages()); + ASSERT_EQ(resp1.inner_case(), HostToEnclaveResponse::kHashes); + ASSERT_EQ(resp1.status(), error::OK); + + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto resp2 = Response(env::test::SentMessages()); + ASSERT_EQ(resp2.inner_case(), HostToEnclaveResponse::kHashes); + ASSERT_EQ(resp2.status(), error::OK); + + EXPECT_EQ(resp1.hashes().db_hash(), resp2.hashes().db_hash()); + EXPECT_EQ(resp1.hashes().commit_idx(), resp2.hashes().commit_idx()); + EXPECT_EQ(resp1.hashes().commit_hash_chain(), resp2.hashes().commit_hash_chain()); + } + } +} + +static e2e::ReplicateStatePush MakeReplicateStatePush( + uint64_t repl_id, + uint64_t seq, + uint64_t first_log, + size_t logs, + bool db_to_end, + size_t rows) { + e2e::ReplicateStatePush p; + p.set_replication_id(repl_id); + p.set_replication_sequence(seq); + p.set_first_log_idx(first_log); + for (size_t i = 0; i < logs; i++) { + p.add_entries(); + } + p.set_db_to_end(db_to_end); + for (size_t i = 0; i < rows; i++) { + p.add_rows(); + } + return p; +} + +static void ReplicateStatePushMatches(const e2e::ReplicateStatePush& a, const e2e::ReplicateStatePush& b) { + LOG(INFO) << "Testing replication ID " << a.replication_id() << "/" << b.replication_id() << " seq " << a.replication_sequence() << "/" << b.replication_sequence(); + EXPECT_EQ(a.replication_id(), b.replication_id()); + EXPECT_EQ(a.replication_sequence(), b.replication_sequence()); + EXPECT_EQ(a.first_log_idx(), b.first_log_idx()); + EXPECT_EQ(a.entries_size(), b.entries_size()); + EXPECT_EQ(a.db_to_end(), b.db_to_end()); + EXPECT_EQ(a.rows_size(), b.rows_size()); +} + +TEST_F(CoreTest, Replicator) { + enclaveconfig::InitConfig cfg = valid_init_config; + cfg.mutable_enclave_config()->mutable_raft()->set_replication_chunk_bytes(10240); + cfg.mutable_enclave_config()->mutable_raft()->set_replication_pipeline(3); + 
auto [core, err] = Core::Create(ctx, cfg); + ASSERT_TRUE(core->ID().Valid()); + CoreMap cores; + cores[core->ID()] = core.get(); + peers::PeerManager pm; + ASSERT_EQ(error::OK, pm.Init(ctx)); + + LOG(INFO) << "\n\nCreating Raft group"; + { + UntrustedMessage msg; + + auto host = msg.mutable_h2e_request(); + host->set_request_id(999); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + ASSERT_EQ(resp.request_id(), 999); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + LOG(INFO) << "\n\nAdding initial rows"; + for (uint8_t i = 0; i < 200; i++) { + client::Request req; + std::array backup_id = {i, 0}; + + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + ClientRequest(cores, core.get(), req, &resp, util::ByteArrayToString(backup_id)); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, resp.backup().status()); + } + + LOG(INFO) << "\n\nConnecting Core to PeerManager"; + { + ASSERT_EQ(error::OK, pm.ConnectToPeer(ctx, core->ID())); + auto msgs = env::test::SentMessages(); + ASSERT_EQ(1, msgs.size()); + ASSERT_EQ(error::OK, core->Receive(ctx, PeerMessage(pm.ID(), core->ID(), msgs[0]))); + msgs = env::test::SentMessages(); + ASSERT_EQ(2, msgs.size()); // synack + timestamp + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::OK, pm.RecvFromPeer(ctx, PeerMessage(core->ID(), pm.ID(), msgs[0]).peer_message(), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_EQ(e2e->inner_case(), e2e::EnclaveToEnclaveMessage::kConnected); + ASSERT_EQ(error::OK, pm.RecvFromPeer(ctx, PeerMessage(core->ID(), pm.ID(), msgs[1]).peer_message(), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_EQ(e2e->inner_case(), 
e2e::EnclaveToEnclaveMessage::kTransactionRequest); + } + + LOG(INFO) << "\n\nGetRaft"; + uint64_t group_id; + uint64_t repl_id = 1; + { + e2e::EnclaveToEnclaveMessage msg; + auto txn = msg.mutable_transaction_request(); + txn->set_request_id(1); + txn->set_get_raft(true); + ASSERT_EQ(error::OK, pm.SendToPeer(ctx, core->ID(), msg)); + auto msgs = env::test::SentMessages(); + ASSERT_EQ(1, msgs.size()); + ASSERT_EQ(error::OK, core->Receive(ctx, PeerMessage(pm.ID(), core->ID(), msgs[0]))); + msgs = env::test::SentMessages(); + ASSERT_EQ(1, msgs.size()); // GetRaft transaction response + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::OK, pm.RecvFromPeer(ctx, PeerMessage(core->ID(), pm.ID(), msgs[0]).peer_message(), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_EQ(e2e->inner_case(), e2e::EnclaveToEnclaveMessage::kTransactionResponse); + ASSERT_EQ(e2e->transaction_response().inner_case(), e2e::TransactionResponse::kGetRaft); + group_id = e2e->transaction_response().get_raft().group_config().group_id(); + ASSERT_NE(group_id, 0); + } + LOG(INFO) << "\n\nReplicateReq"; + std::string last_backup_id = ""; + std::deque txns; + { + e2e::EnclaveToEnclaveMessage msg; + auto txn = msg.mutable_transaction_request(); + txn->set_request_id(12345); + txn->mutable_replicate_state()->set_group_id(group_id); + txn->mutable_replicate_state()->set_replication_id(repl_id); + ASSERT_EQ(error::OK, pm.SendToPeer(ctx, core->ID(), msg)); + auto msgs = env::test::SentMessages(); + ASSERT_EQ(1, msgs.size()); + ASSERT_EQ(error::OK, core->Receive(ctx, PeerMessage(pm.ID(), core->ID(), msgs[0]))); + msgs = env::test::SentMessages(); + std::vector expected_pipeline = { + MakeReplicateStatePush(repl_id, 0, 1, 80, false, 0), + MakeReplicateStatePush(repl_id, 1, 81, 79, false, 0), + MakeReplicateStatePush(repl_id, 2, 160, 42, false, 44), + }; + ASSERT_EQ(expected_pipeline.size(), msgs.size()); // pipelining + for (size_t i = 0; i < expected_pipeline.size(); i++) { + e2e::EnclaveToEnclaveMessage* e2e; +
ASSERT_EQ(error::OK, pm.RecvFromPeer(ctx, PeerMessage(core->ID(), pm.ID(), msgs[i]).peer_message(), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_EQ(e2e->inner_case(), e2e::EnclaveToEnclaveMessage::kTransactionRequest); + ASSERT_EQ(e2e->transaction_request().inner_case(), e2e::TransactionRequest::kReplicateStatePush); + ReplicateStatePushMatches( + e2e->transaction_request().replicate_state_push(), + expected_pipeline[i]); + txns.push_back(e2e->transaction_request().request_id()); + for (const auto& row : e2e->transaction_request().replicate_state_push().rows()) { + e2e::DB2RowState rs; + ASSERT_TRUE(rs.ParseFromString(row)); + ASSERT_LT(last_backup_id, rs.backup_id()); + last_backup_id = rs.backup_id(); + } + } + } + LOG(INFO) << "\n\nAdding intermediate rows"; + for (uint8_t i = 0; i < 200; i++) { + client::Request req; + std::array backup_id = {i, 1}; + + auto b = req.mutable_backup(); + b->set_data("12345678901234567890123456789012"); + b->set_pin("12345678901234567890123456789012"); + b->set_max_tries(10); + client::Response resp; + ClientRequest(cores, core.get(), req, &resp, util::ByteArrayToString(backup_id)); + ASSERT_EQ(client::Response::kBackup, resp.inner_case()); + ASSERT_EQ(client::BackupResponse::OK, resp.backup().status()); + } + std::deque expected_pushes = { + MakeReplicateStatePush(repl_id, 3, 202, 79, false, 0), + MakeReplicateStatePush(repl_id, 4, 281, 79, false, 0), + MakeReplicateStatePush(repl_id, 5, 360, 42, false, 44), + MakeReplicateStatePush(repl_id, 6, 402, 0, false, 97), + MakeReplicateStatePush(repl_id, 7, 402, 0, false, 97), + MakeReplicateStatePush(repl_id, 8, 402, 0, true, 75), + }; + while (expected_pushes.size()) { + e2e::EnclaveToEnclaveMessage msg; + auto txn = msg.mutable_transaction_response(); + ASSERT_GT(txns.size(), 0); + txn->set_request_id(txns.front()); + txns.pop_front(); + txn->set_status(error::OK); + ASSERT_EQ(error::OK, pm.SendToPeer(ctx, core->ID(), msg)); + auto msgs = env::test::SentMessages(); + ASSERT_EQ(1, 
msgs.size()); + ASSERT_EQ(error::OK, core->Receive(ctx, PeerMessage(pm.ID(), core->ID(), msgs[0]))); + msgs = env::test::SentMessages(); + ASSERT_EQ(msgs.size(), 1); + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::OK, pm.RecvFromPeer(ctx, PeerMessage(core->ID(), pm.ID(), msgs[0]).peer_message(), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_EQ(e2e->inner_case(), e2e::EnclaveToEnclaveMessage::kTransactionRequest); + ASSERT_EQ(e2e->transaction_request().inner_case(), e2e::TransactionRequest::kReplicateStatePush); + ReplicateStatePushMatches( + e2e->transaction_request().replicate_state_push(), + expected_pushes.front()); + expected_pushes.pop_front(); + txns.push_back(e2e->transaction_request().request_id()); + for (const auto& row : e2e->transaction_request().replicate_state_push().rows()) { + e2e::DB2RowState rs; + ASSERT_TRUE(rs.ParseFromString(row)); + ASSERT_LT(last_backup_id, rs.backup_id()); + last_backup_id = rs.backup_id(); + } + } + while (txns.size() > 1) { + e2e::EnclaveToEnclaveMessage msg; + auto txn = msg.mutable_transaction_response(); + txn->set_request_id(txns.front()); + txns.pop_front(); + txn->set_status(error::OK); + ASSERT_EQ(error::OK, pm.SendToPeer(ctx, core->ID(), msg)); + auto msgs = env::test::SentMessages(); + ASSERT_EQ(1, msgs.size()); + ASSERT_EQ(error::OK, core->Receive(ctx, PeerMessage(pm.ID(), core->ID(), msgs[0]))); + msgs = env::test::SentMessages(); + ASSERT_EQ(msgs.size(), 0); + } + { + e2e::EnclaveToEnclaveMessage msg; + auto txn = msg.mutable_transaction_response(); + ASSERT_EQ(txns.size(), 1); + txn->set_request_id(txns.front()); + txns.pop_front(); + txn->set_status(error::OK); + ASSERT_EQ(error::OK, pm.SendToPeer(ctx, core->ID(), msg)); + auto msgs = env::test::SentMessages(); + ASSERT_EQ(1, msgs.size()); + ASSERT_EQ(error::OK, core->Receive(ctx, PeerMessage(pm.ID(), core->ID(), msgs[0]))); + msgs = env::test::SentMessages(); + ASSERT_EQ(msgs.size(), 1); + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::OK, 
pm.RecvFromPeer(ctx, PeerMessage(core->ID(), pm.ID(), msgs[0]).peer_message(), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_EQ(e2e->inner_case(), e2e::EnclaveToEnclaveMessage::kTransactionResponse); + ASSERT_EQ(e2e->transaction_response().inner_case(), e2e::TransactionResponse::kStatus); + ASSERT_EQ(e2e->transaction_response().status(), error::OK); + ASSERT_EQ(e2e->transaction_response().request_id(), 12345); + } +} + + +TEST_F(CoreTest, BackupResetsNumTries) { + + ReplicaGroupConfig cfg = { + .ecfg = valid_enclave_config, + .min_voting = 1, + .max_voting = 1, + .initial_voting = 3, + .initial_nonvoting = 0, + .initial_nonmember = 0 + }; + ReplicaGroup replica_group{}; + replica_group.Init(cfg.init_config(), cfg.initial_voting, cfg.initial_nonvoting, cfg.initial_nonmember); + + // tik tok + replica_group.TickTock(false); + replica_group.TickTock(false); + + auto [pin, e1] = util::StringToByteArray<32>("PIN45678901234567890123456789012"); + auto [wrong_pin, e2] = util::StringToByteArray<32>("SIN45678901234567890123456789012"); + auto [secret, e3] = util::StringToByteArray<48>("SECRET78901234567890123456789012"); + ASSERT_TRUE(e1 == error::OK && e2 == error::OK && e3 == error::OK); + size_t num_tries = 3; + + size_t core_num = 0; // connect to the leader + auto client_core = replica_group.get_core(core_num); + + // Client requests backup + { + TestingClient cl(*client_core, "authenticated_id"); + + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + cl.RequestBackup(secret, pin, num_tries); + replica_group.TickTock(false); + + auto backup_response = cl.get_backup_response(); + ASSERT_NE(backup_response, nullptr); + LOG(INFO) << "created backup"; + } + + // Client requests expose + { + TestingClient cl(*client_core, "authenticated_id"); + + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + cl.RequestExpose(secret); + replica_group.TickTock(false); + + auto expose_response = 
cl.get_expose_response(); + ASSERT_NE(expose_response, nullptr); + LOG(INFO) << "created expose"; + } + + // Client requests restore with wrong pin + { + client_core = replica_group.get_core(core_num); + TestingClient cl(*client_core, "authenticated_id"); + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + cl.RequestRestore(wrong_pin); + replica_group.TickTock(false); + + auto restore_response = cl.get_restore_response(); + ASSERT_NE(restore_response, nullptr); + ASSERT_EQ(restore_response->tries(), num_tries - 1); + ASSERT_NE(util::ByteArrayToString(secret), restore_response->data()); + + } + + + // Client requests backup again + { + TestingClient cl(*client_core, "authenticated_id"); + + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + cl.RequestBackup(secret, pin, num_tries); + replica_group.TickTock(false); + + auto backup_response = cl.get_backup_response(); + ASSERT_NE(backup_response, nullptr); + LOG(INFO) << "created backup"; + } + + // Client requests expose again + { + TestingClient cl(*client_core, "authenticated_id"); + + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + + cl.RequestExpose(secret); + replica_group.TickTock(false); + + auto expose_response = cl.get_expose_response(); + ASSERT_NE(expose_response, nullptr); + LOG(INFO) << "created expose"; + } + + // Client requests restore again and checks that the number of tries is correct + { + client_core = replica_group.get_core(core_num); + TestingClient cl(*client_core, "authenticated_id"); + cl.RequestHandshake(); + replica_group.TickTock(false); + replica_group.TickTock(false); + cl.RequestRestore(pin); + replica_group.TickTock(false); + + auto restore_response = cl.get_restore_response(); + ASSERT_NE(restore_response, nullptr); + ASSERT_EQ(restore_response->tries(), num_tries); + ASSERT_EQ(util::ByteArrayToString(secret), restore_response->data()); + + } +} + 
+TEST_F(CoreTest, MultiNodeRaftSVR3) { + enclaveconfig::InitConfig config = valid_init_config; + config.mutable_group_config()->set_db_version(enclaveconfig::DATABASE_VERSION_SVR3); + auto [core1, err1] = Core::Create(ctx, config); + ASSERT_EQ(err1, error::OK); + auto [core2, err2] = Core::Create(ctx, config); + ASSERT_EQ(err2, error::OK); + auto [core3, err3] = Core::Create(ctx, config); + ASSERT_EQ(err3, error::OK); + LOG(INFO) << "core1=" << core1->ID() << ", core2=" << core2->ID() << ", core3=" << core3->ID(); + + // Create cores map for PassMessages + CoreMap cores; + cores[core1->ID()] = core1.get(); + cores[core2->ID()] = core2.get(); + cores[core3->ID()] = core3.get(); + + { + LOG(INFO) << "\n\nSet up as one-replica Raft on core 1"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1000); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core1->Receive(&ctx, msg)); + auto out = env::test::SentMessages(); + ASSERT_EQ(1, out.size()); + auto resp = out[0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1000); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest join on core 2"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1001); + auto req = host->mutable_join_raft(); + core1->ID().ToString(req->mutable_peer_id()); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1001); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest core2 vote"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1002); + host->set_request_voting(true); + + 
context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + auto out = PassMessages(cores, core2.get()); + ASSERT_EQ(1, out[core2->ID()].size()); + auto resp = out[core2->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1002); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest join on core 3"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1003); + auto req = host->mutable_join_raft(); + core1->ID().ToString(req->mutable_peer_id()); + + context::Context ctx; + ASSERT_EQ(error::OK, core3->Receive(&ctx, msg)); + auto out = PassMessages(cores, core3.get()); + ASSERT_EQ(1, out[core3->ID()].size()); + auto resp = out[core3->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1003); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + { + LOG(INFO) << "\n\nRequest core3 vote"; + UntrustedMessage msg; + auto host = msg.mutable_h2e_request(); + host->set_request_id(1004); + host->set_request_voting(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core3->Receive(&ctx, msg)); + auto out = PassMessages(cores, core3.get()); + ASSERT_EQ(1, out[core3->ID()].size()); + auto resp = out[core3->ID()][0].h2e_response(); + ASSERT_EQ(resp.request_id(), 1004); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + EXPECT_TRUE(core1->serving()); + EXPECT_TRUE(core1->leader()); + EXPECT_TRUE(core2->serving()); + EXPECT_FALSE(core2->leader()); + EXPECT_TRUE(core3->serving()); + EXPECT_FALSE(core3->leader()); + + LOG(INFO) << "\n\nRequest to leader core1"; + client::Request3 req; + auto b = req.mutable_create(); + b->set_max_tries(10); + b->mutable_blinded_element()->resize(db::DB3::ELEMENT_SIZE); + crypto_core_ristretto255_random( + reinterpret_cast(b->mutable_blinded_element()->data())); + + client::Response3 resp; 
+ ClientRequest(cores, core1.get(), req, &resp, "backup7890123456"); + ASSERT_EQ(client::Response3::kCreate, resp.inner_case()); + ASSERT_EQ(client::CreateResponse::OK, resp.create().status()); + + LOG(INFO) << "\n\nElecting next leader"; + const int max_attempts = 100; + cores.erase(core1->ID()); // core1 goes offline + for (int i = 0; i < max_attempts && !core2->leader(); i++) { + LOG(INFO) << "core2 tick"; + UntrustedMessage msg; + msg.mutable_timer_tick()->set_new_timestamp_unix_secs(i); + + context::Context ctx; + ASSERT_EQ(error::OK, core2->Receive(&ctx, msg)); + PassMessages(cores, core2.get()); + } + EXPECT_TRUE(core2->serving()); + EXPECT_TRUE(core2->leader()); + EXPECT_TRUE(core3->serving()); + EXPECT_FALSE(core3->leader()); + + LOG(INFO) << "\n\nRequest to leader core2"; + + client::Request3 req2; + auto r = req2.mutable_evaluate(); + r->set_blinded_element(req.create().blinded_element()); + client::Response3 resp2; + ClientRequest(cores, core2.get(), req2, &resp2, "backup7890123456"); + ASSERT_EQ(client::Response3::kEvaluate, resp2.inner_case()); + ASSERT_EQ(client::EvaluateResponse::OK, resp2.evaluate().status()); + ASSERT_EQ(resp.create().evaluated_element(), resp2.evaluate().evaluated_element()); +} + +TEST_F(CoreTest, Hashes3) { + enclaveconfig::InitConfig config = valid_init_config; + config.mutable_group_config()->set_db_version(enclaveconfig::DATABASE_VERSION_SVR3); + auto [core, err] = Core::Create(ctx, config); + ASSERT_TRUE(core->ID().Valid()); + CoreMap cores; + cores[core->ID()] = core.get(); + + std::string blinded; + blinded.resize(db::DB3::ELEMENT_SIZE); + crypto_core_ristretto255_random( + reinterpret_cast(blinded.data())); + + { // Set up as one-replica Raft + UntrustedMessage msg; + + auto host = msg.mutable_h2e_request(); + host->set_request_id(999); + host->set_create_new_raft_group(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + 
ASSERT_EQ(resp.request_id(), 999); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kStatus); + ASSERT_EQ(resp.status(), error::OK); + } + + LOG(INFO) << "sending backup request"; + + client::Request3 req; + auto b = req.mutable_create(); + b->set_blinded_element(blinded); + b->set_max_tries(10); + client::Response3 resp; + ClientRequest(cores, core.get(), req, &resp, "backup7890123456"); + ASSERT_EQ(client::Response3::kCreate, resp.inner_case()); + ASSERT_EQ(client::CreateResponse::OK, resp.create().status()); + + { + UntrustedMessage msg; + + auto host = msg.mutable_h2e_request(); + host->set_request_id(10101); + host->set_hashes(true); + + context::Context ctx; + ASSERT_EQ(error::OK, core->Receive(&ctx, msg)); + auto resp = Response(env::test::SentMessages()); + ASSERT_EQ(resp.request_id(), 10101); + ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kHashes); + ASSERT_EQ(resp.status(), error::OK); + EXPECT_EQ(util::BigEndian64FromBytes(reinterpret_cast(resp.hashes().db_hash().data())), + 11717402061570123096ULL); + EXPECT_EQ(resp.hashes().commit_idx(), 2); + EXPECT_EQ(util::BigEndian64FromBytes(reinterpret_cast(resp.hashes().commit_hash_chain().data())), + 9806922570174040741ULL); + } +} + +} // namespace svr2::core diff --git a/enclave/db/db.cc b/enclave/db/db.cc new file mode 100644 index 0000000..5c300df --- /dev/null +++ b/enclave/db/db.cc @@ -0,0 +1,27 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "db/db.h" +#include "db/db2.h" +#include "db/db3.h" + +#include + +namespace svr2::db { + +std::unique_ptr DB::New(enclaveconfig::DatabaseVersion version) { + std::unique_ptr out; + switch (version) { + case enclaveconfig::DATABASE_VERSION_SVR2: + out.reset(new db::DB2()); + break; + case enclaveconfig::DATABASE_VERSION_SVR3: + out.reset(new db::DB3()); + break; + default: + return nullptr; + } + return out; +} + +} // namespace svr2::db diff --git a/enclave/db/db.h b/enclave/db/db.h new file mode 100644 
index 0000000..1ca769a --- /dev/null +++ b/enclave/db/db.h @@ -0,0 +1,122 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_DB_DB_H__ +#define __SVR2_DB_DB_H__ + +#include +#include +#include + +#include "proto/error.pb.h" +#include "proto/e2e.pb.h" +#include "proto/msgs.pb.h" +#include "sip/hasher.h" +#include "context/context.h" +#include "util/log.h" + +namespace svr2::db { + +// DB provides a generic interface for databases, which can be used by both +// SVR2 (db2.*) and SVR3 (db3.*). These two databases take in different +// requests and return different responses, which are packaged in the +// DB::Protocol interface and implemented per-database. +// +// A database uses three objects during its lifecycle: +// - Request: a protobuf created and provided by a (remote) client +// - Log: generated from a `Request` and contains the operation to be performed +// - Response: returned to the (remote) client detailing the output of the operation +// +// In many cases, the Request and Log will be similar, and often the Request +// is simply embedded into the Log. However, the Log generally contains a +// few other key pieces of information: +// - the database key associated with the request/authenticated_id, if there is one +// - any information (entropy, timestamps, etc) which could differ if recomputed +// across different replicas. +// +// Generally, the lifecycle of a request is: +// - the Request is received by one replica +// - that replica uses it to generate a Log +// - that Log is submitted to Raft for ordering and persistence +// - Raft commits the Log +// - the Log is then applied to the database via the Run method +// - the Run method generates a Response +// - on the replica that received the Request, the Response is returned to the client +class DB { + public: + DELETE_COPY_AND_ASSIGN(DB); + DB() {} + virtual ~DB() {} + + // Returns a database based on the passed-in version number.
+ static std::unique_ptr New(enclaveconfig::DatabaseVersion version); + + typedef google::protobuf::MessageLite Request; + typedef google::protobuf::MessageLite Log; + typedef google::protobuf::MessageLite Response; + + // Protocol encapsulates typing requests and responses for clients. + class Protocol { + public: + // RequestPB creates a new request protobuf in the scope of `ctx` + virtual Request* RequestPB(context::Context* ctx) const = 0; + // LogPB creates a new log protobuf in the scope of `ctx` + virtual Log* LogPB(context::Context* ctx) const = 0; + // Given a request, creates a log. Note that this potentially std::move's + // the request into the log, so care should be taken to not use the request + // after calling LogPBFromRequest. + virtual std::pair LogPBFromRequest( + context::Context* ctx, + Request&& request, + const std::string& authenticated_id) const = 0; + // LogKey returns the database key associated with the given request proto. + virtual const std::string& LogKey(const Log& r) const = 0; + // Validate that a log has the right shape, size, etc. + virtual error::Error ValidateClientLog(const Log& log) const = 0; + // Returns the maximum size of a database row when serialized. + virtual size_t MaxRowSerializedSize() const = 0; + }; + // P() returns a pointer to a _static_ Protocol object, + // which will outlast the DB object. + virtual const Protocol* P() const = 0; + + // Run a client log request and yield a response. + // The client log should already have been checked with ValidateClientLog; + // failing to do so will CHECK-fail. + // It's assumed that validation happens on Raft log insert, so that + // outputs from the Raft log are already validated. + // + // Output response is valid within the passed-in context. + virtual Response* Run(context::Context* ctx, const Log& log) = 0; + + // Get rows from this database in range (exclusive_start, ...], returning + // no more than [size] rows. 
If it returns <[size] rows, the end of the database + // has been reached. Pass in empty string to start with the first key in + // the database. Returns the key of the largest returned row. + virtual std::pair RowsAsProtos( + context::Context* ctx, + const std::string& exclusive_start, + size_t size, + google::protobuf::RepeatedPtrField* out) const = 0; + // Update this database using the given database row states. + // This will return an error if any of the DatabaseRowStates contain + // rows that already exist within the database. Rows must be lexicographically + // larger than any existing row in the database. Returns the row key + // of the last row inserted into the database, on success. + virtual std::pair LoadRowsFromProtos( + context::Context* ctx, + const google::protobuf::RepeatedPtrField& rows) = 0; + + // Compute a hash of the entire database. This is not designed to + // be useful for security-focussed integrity checking, but should be + // sufficient to verify that replicated data matches up between source + // and destination.
+ virtual std::array Hash(context::Context* ctx) const = 0; + + // Get the number of backups stored in the database + virtual size_t row_count() const = 0; +}; + +} // namespace svr2::db + +#endif // __SVR2_DB_DB_H__ diff --git a/enclave/db/db2.cc b/enclave/db/db2.cc new file mode 100644 index 0000000..82b4181 --- /dev/null +++ b/enclave/db/db2.cc @@ -0,0 +1,342 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "db/db2.h" + +#include +#include + +#include "util/log.h" +#include "util/bytes.h" +#include "util/hex.h" +#include "util/constant.h" +#include "util/endian.h" +#include "context/context.h" +#include "metrics/metrics.h" +#include "proto/clientlog.pb.h" + +namespace svr2::db { + +template +static void CopyArrayToString(const T& array, std::string* out) { + CHECK(array.size() == 0 || sizeof(array[0]) == 1); + out->resize(array.size()); + std::copy(array.cbegin(), array.cend(), out->begin()); +} + +static size_t SMALL_BYTES_FIELD_EXTRA_PROTO_METADATA = 2; +static size_t U16_AS_VARINT_MAX_SIZE = 3; +size_t DB2::Protocol::MaxRowSerializedSize() const { + return + BACKUP_ID_SIZE + SMALL_BYTES_FIELD_EXTRA_PROTO_METADATA + + MAX_DATA_SIZE + SMALL_BYTES_FIELD_EXTRA_PROTO_METADATA + + PIN_SIZE + SMALL_BYTES_FIELD_EXTRA_PROTO_METADATA + + U16_AS_VARINT_MAX_SIZE; // max bytes for TRIES +} + +DB::Request* DB2::Protocol::RequestPB(context::Context* ctx) const { + return ctx->Protobuf(); +} + +DB::Log* DB2::Protocol::LogPB(context::Context* ctx) const { + return ctx->Protobuf(); +} + +std::pair DB2::Protocol::LogPBFromRequest( + context::Context* ctx, + Request&& request, + const std::string& authenticated_id) const { + auto r = dynamic_cast(&request); + if (r == nullptr) { + return std::make_pair(nullptr, COUNTED_ERROR(DB2_InvalidRequestType)); + } + auto log = ctx->Protobuf(); + if (authenticated_id.size() != BACKUP_ID_SIZE) { + return std::make_pair(nullptr, COUNTED_ERROR(DB2_ClientBackupIDSize)); + } + 
log->set_backup_id(authenticated_id); + *log->mutable_req() = std::move(*r); + return std::make_pair(log, error::OK); +} + +const std::string& DB2::Protocol::LogKey(const DB::Log& req) const { + auto r = dynamic_cast(&req); + CHECK(r != nullptr); + return r->backup_id(); +} + +error::Error DB2::Protocol::ValidateClientLog(const DB::Log& req_pb) const { + auto log = dynamic_cast(&req_pb); + if (log == nullptr) { return COUNTED_ERROR(DB2_InvalidRequestType); } + auto req = log->req(); + + if (log->backup_id().size() != BACKUP_ID_SIZE) { return COUNTED_ERROR(DB2_ClientBackupIDSize); } + switch (req.inner_case()) { + case client::Request::kBackup: { + auto r = req.backup(); + if (r.pin().size() != PIN_SIZE) { return COUNTED_ERROR(DB2_ClientPinSize); } + if (r.data().size() > MAX_DATA_SIZE) { return COUNTED_ERROR(DB2_ClientDataSize); } + if (r.data().size() < MIN_DATA_SIZE) { return COUNTED_ERROR(DB2_ClientDataSize); } + if (r.max_tries() > MAX_ALLOWED_MAX_TRIES) { return COUNTED_ERROR(DB2_ClientTriesTooHigh); } + if (r.max_tries() < MIN_ALLOWED_MAX_TRIES) { return COUNTED_ERROR(DB2_ClientTriesZero); } + } break; + case client::Request::kRestore: { + auto r = req.restore(); + if (r.pin().size() != PIN_SIZE) { return COUNTED_ERROR(DB2_ClientPinSize); } + } break; + case client::Request::kDelete: { + auto r = req.delete_(); + } break; + case client::Request::kExpose: { + auto r = req.expose(); + if (r.data().size() > MAX_DATA_SIZE) { return COUNTED_ERROR(DB2_ClientDataSize); } + if (r.data().size() < MIN_DATA_SIZE) { return COUNTED_ERROR(DB2_ClientDataSize); } + } break; + default: + return COUNTED_ERROR(DB2_ClientRequestCase); + } + return error::OK; +} + +const DB::Protocol* DB2::P() const { + static DB2::Protocol rr; + return &rr; +} + +DB::Response* DB2::Run(context::Context* ctx, const DB::Log& log_pb) { + // We CHECK here because this should have already been validated when it + // was added to the Raft log. 
+ MEASURE_CPU(ctx, cpu_db_client_request); + CHECK(error::OK == P()->ValidateClientLog(log_pb)); + auto log = reinterpret_cast(log_pb); // dynamic_cast checked in ValidateClientLog. + BackupID id; + CHECK(log.backup_id().size() == id.size()); + std::copy(log.backup_id().begin(), log.backup_id().end(), id.begin()); + auto resp = ctx->Protobuf(); + switch (log.req().inner_case()) { + case client::Request::kBackup: + Backup(id, log.req().backup(), resp->mutable_backup()); + break; + case client::Request::kRestore: + Restore(id, log.req().restore(), resp->mutable_restore()); + break; + case client::Request::kDelete: + Delete(id, log.req().delete_(), resp->mutable_delete_()); + break; + case client::Request::kExpose: + Expose(id, log.req().expose(), resp->mutable_expose()); + break; + default: + LOG(WARNING) << "unsupported request case, returning empty response"; + break; + } + return resp; +} + +void DB2::Row::Clear(e2e::DB2RowState::State s) { + memset(data.begin(), 0, data.size()); + memset(pin.begin(), 0, pin.size()); + tries = 0; + data_size = 0; + state = s; +} + +void DB2::Backup(const BackupID& id, const client::BackupRequest& req, client::BackupResponse* resp) { + std::map, Row>::iterator find = rows_.find(id); + if (find == rows_.end()) { + auto e = rows_.emplace( + std::piecewise_construct, + std::forward_as_tuple(std::move(id)), + std::forward_as_tuple()); + find = e.first; + GAUGE(db, rows)->Set(rows_.size()); + } + Row* row = &find->second; + row->Clear(e2e::DB2RowState::POPULATED); + std::copy(req.data().begin(), req.data().end(), row->data.begin()); + row->data_size = req.data().size(); + row->tries = req.max_tries(); + std::copy(req.pin().begin(), req.pin().end(), row->pin.begin()); + resp->set_status(client::BackupResponse::OK); +} + +void DB2::Restore(const BackupID& id, const client::RestoreRequest& req, client::RestoreResponse* resp) { + auto find = rows_.find(id); + if (find == rows_.end() || find->second.state != e2e::DB2RowState::AVAILABLE) { + 
resp->set_status(client::RestoreResponse::MISSING); + return; + } + Row* row = &find->second; + if (util::ConstantTimeEquals(req.pin(), row->pin)) { + resp->set_status(client::RestoreResponse::OK); + resp->set_tries(row->tries); + *resp->mutable_data() = std::string(row->data.begin(), row->data.begin() + row->data_size); + return; + } + if (--row->tries == 0) { + // We Clear before erasing because erasing just removes the entry from the map, and + // we want to actually zero out the secret wherever it is in memory. + row->Clear(e2e::DB2RowState::UNINITIATED); + rows_.erase(find); + resp->set_status(client::RestoreResponse::MISSING); + GAUGE(db, rows)->Set(rows_.size()); + return; + } + resp->set_status(client::RestoreResponse::PIN_MISMATCH); + resp->set_tries(row->tries); +} + +void DB2::Delete(const BackupID& id, const client::DeleteRequest& req, client::DeleteResponse* resp) { + auto find = rows_.find(id); + if (find == rows_.end()) { return; } + // We Clear before erasing because erasing just removes the entry from the map, and + // we want to actually zero out the secret wherever it is in memory. + find->second.Clear(e2e::DB2RowState::UNINITIATED); + rows_.erase(find); + GAUGE(db, rows)->Set(rows_.size()); +} + +void DB2::Expose(const BackupID& id, const client::ExposeRequest& req, client::ExposeResponse* resp) { + // Expose provides a 2-phase commit of backups, to avoid client backup + // retries from allowing server operators infinite guesses against the pin. + // Without Expose, the following attack is possible: + // 1. client sends BackupRequest + // 2. server processes BackupRequest + // 3. server operator drops connection to client before BackupResponse is sent + // 4. server operator makes max_tries guesses against backup + // 5. client retries BackupRequest (goto 1) + // + // The Expose proto must contain the secret to make sure that only someone + // that already knows the secret (i.e. the client) can expose the backup for + // restores.
Otherwise, the following attack is possible: + // 1. client sends BackupRequest + // 2. server processes BackupRequest + // 3. server operator drops connection to client before BackupResponse is sent + // 4. server operator sends ExposeRequest to enclave, which processes it + // 5. server operator makes max_tries guesses against backup + // 6. client retries BackupRequest (goto 1) + auto find = rows_.find(id); + if (find == rows_.end()) { + resp->set_status(client::ExposeResponse::ERROR); + return; + } + Row* row = &find->second; + if (!util::ConstantTimeEqualsPrefix(row->data, req.data(), row->data_size)) { + resp->set_status(client::ExposeResponse::ERROR); + return; + } + switch (row->state) { + case e2e::DB2RowState::POPULATED: + case e2e::DB2RowState::AVAILABLE: + row->state = e2e::DB2RowState::AVAILABLE; + resp->set_status(client::ExposeResponse::OK); + return; + default: + resp->set_status(client::ExposeResponse::ERROR); + return; + } +} + +std::pair DB2::RowsAsProtos(context::Context* ctx, const std::string& exclusive_start, size_t size, google::protobuf::RepeatedPtrField* out) const { + MEASURE_CPU(ctx, cpu_db_repl_send); + auto iter = rows_.begin(); + if (!exclusive_start.empty()) { + auto [id, err] = BackupIDFromString(exclusive_start); + if (err != error::OK) { + return std::make_pair("", err); + } + iter = rows_.upper_bound(id); + } + auto row = ctx->Protobuf(); + std::string last_id; + for (size_t i = 0; i < size && iter != rows_.end(); i++, ++iter) { + row->Clear(); + CopyArrayToString(iter->first, row->mutable_backup_id()); + CopyArrayToString(iter->second.data, row->mutable_data()); + row->mutable_data()->resize(iter->second.data_size); + CopyArrayToString(iter->second.pin, row->mutable_pin()); + row->set_tries(iter->second.tries); + row->set_state(iter->second.state); + if (!row->SerializeToString(out->Add())) { + return std::make_pair("", COUNTED_ERROR(DB2_ReplicationInvalidRow)); + } + last_id = row->backup_id(); + } + LOG(DEBUG) << "DB sending 
rows in (" << util::PrefixToHex(exclusive_start, 8) << ", " << util::PrefixToHex(last_id, 8) << "]"; + return std::make_pair(last_id, error::OK); +} + +DB2::Row::Row() : state(e2e::DB2RowState::UNINITIATED), tries(0), data_size(0), data{0}, pin{0} {} + +std::pair DB2::LoadRowsFromProtos(context::Context* ctx, const google::protobuf::RepeatedPtrField& rows) { + MEASURE_CPU(ctx, cpu_db_repl_recv); + CHECK(rows.size()); + size_t initial_rows = rows_.size(); + auto row = ctx->Protobuf(); + for (int i = 0; i < rows.size(); i++) { + row->Clear(); + if (!row->ParseFromString(rows.Get(i))) { + return std::make_pair("", COUNTED_ERROR(DB2_ReplicationInvalidRow)); + } + if (row->tries() > MAX_ALLOWED_MAX_TRIES || + row->pin().size() != PIN_SIZE || + row->data().size() < MIN_DATA_SIZE || + row->data().size() > MAX_DATA_SIZE) { + return std::make_pair("", COUNTED_ERROR(DB2_ReplicationInvalidRow)); + } + auto [key, err] = BackupIDFromString(row->backup_id()); + if (err != error::OK) { + return std::make_pair("", err); + } + if (rows_.size() && key <= rows_.rbegin()->first) { + return std::make_pair("", COUNTED_ERROR(DB2_ReplicationOutOfOrder)); + } + + Row r; + r.state = row->state(); + std::copy(row->pin().begin(), row->pin().end(), r.pin.begin()); + std::copy(row->data().begin(), row->data().end(), r.data.begin()); + r.data_size = row->data().size(); + r.tries = row->tries(); + rows_.emplace_hint(rows_.end(), key, std::move(r)); + GAUGE(db, rows)->Set(rows_.size()); + } + if (rows_.size() != initial_rows + rows.size()) { + // This ensures that we didn't accidentally attempt to load rows that + // already exist within the DB. 
+ return std::make_pair("", COUNTED_ERROR(DB2_LoadedRowsAlreadyInDB)); + } + return std::make_pair(row->backup_id(), error::OK); +} + +std::pair DB2::BackupIDFromString(const std::string& s) { + DB2::BackupID out; + if (s.size() != BACKUP_ID_SIZE) { + return std::make_pair(std::move(out), COUNTED_ERROR(DB2_BackupIDSize)); + } + std::copy(s.begin(), s.end(), out.data()); + return std::make_pair(std::move(out), error::OK); +} + +std::array DB2::Hash(context::Context* ctx) const { + MEASURE_CPU(ctx, cpu_db_hash); + crypto_hash_sha256_state sha; + crypto_hash_sha256_init(&sha); + uint8_t num[8]; + util::BigEndian64Bytes(rows_.size(), num); + crypto_hash_sha256_update(&sha, num, sizeof(num)); + for (auto iter = rows_.cbegin(); iter != rows_.cend(); ++iter) { + util::BigEndian64Bytes(iter->second.state, num); + crypto_hash_sha256_update(&sha, num, sizeof(num)); + crypto_hash_sha256_update(&sha, iter->first.data(), iter->first.size()); + util::BigEndian64Bytes(iter->second.tries, num); + crypto_hash_sha256_update(&sha, num, sizeof(num)); + crypto_hash_sha256_update(&sha, iter->second.data.data(), iter->second.data_size); + crypto_hash_sha256_update(&sha, iter->second.pin.data(), iter->second.pin.size()); + } + std::array out; + crypto_hash_sha256_final(&sha, out.data()); + return out; +} + +} // namespace svr2::db diff --git a/enclave/db/db2.h b/enclave/db/db2.h new file mode 100644 index 0000000..ae34a8f --- /dev/null +++ b/enclave/db/db2.h @@ -0,0 +1,113 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_DB_DB2_H__ +#define __SVR2_DB_DB2_H__ + +#include +#include +#include "proto/error.pb.h" +#include "proto/e2e.pb.h" +#include "sip/hasher.h" +#include "context/context.h" +#include "util/log.h" +#include "db/db.h" +#include "proto/client.pb.h" + +namespace svr2::db { + +// DB2 implements the DB interface for SVR2. +// DB is a database meant to be driven by a Raft log. 
+// Raft stores an ordered, consistent list of committed client::Request requests. +// This DB executes those requests as CRUD operations on an underlying ordered map, +// and returns their respective responses. +class DB2 : public DB { + public: + DELETE_COPY_AND_ASSIGN(DB2); + DB2() {} + virtual ~DB2() {} + + class Protocol : public DB::Protocol { + public: + virtual Request* RequestPB(context::Context* ctx) const; + virtual Log* LogPB(context::Context* ctx) const; + virtual std::pair LogPBFromRequest( + context::Context* ctx, + Request&& request, + const std::string& authenticated_id) const; + virtual const std::string& LogKey(const Log& r) const; + virtual error::Error ValidateClientLog(const Log& log) const; + virtual size_t MaxRowSerializedSize() const; + }; + virtual const DB::Protocol* P() const; + + // Run a client log request and yield a response. + // The client log should already have been checked with ValidateClientLog; + // failing to do so will CHECK-fail. + // It's assumed that validation happens on Raft log insert, so that + // outputs from the Raft log are already validated. + // + // Output response is valid within the passed-in context. + virtual Response* Run(context::Context* ctx, const Log& request); + + // Limits on sizes/etc for validation. + static const size_t BACKUP_ID_SIZE = 16; + static const size_t MIN_DATA_SIZE = 16; + static const size_t MAX_DATA_SIZE = 48; + static const size_t PIN_SIZE = 32; + static const uint16_t MAX_ALLOWED_MAX_TRIES = 255; + static const uint16_t MIN_ALLOWED_MAX_TRIES = 1; + + // Get rows from this database in range (exclusive_start, ...], returning + // no more than [size] rows. If it returns <[size] rows, the end of the database + // has been reached. Pass in DB::Beginning to start with the first key in + // the database. 
+ virtual std::pair RowsAsProtos(context::Context* ctx, const std::string& exclusive_start, size_t size, google::protobuf::RepeatedPtrField* out) const; + // Update this database using the given database row states. + // This will return an error if any of the DB2RowStates contain + // rows that already exist within the database. Rows must be lexicographically + // larger than any existing row in the database. Returns the row key + // of the last row inserted into the database, on success. + virtual std::pair LoadRowsFromProtos(context::Context* ctx, const google::protobuf::RepeatedPtrField& rows); + + // Compute a hash of the entire database. This is not designed to + // be useful for security-focussed integrity checking, but should be + // sufficient to verify that replicated data matches up between source + // and destination. + virtual std::array Hash(context::Context* ctx) const; + + // Get the number of backups stored in the database + virtual size_t row_count() const { return rows_.size(); } + private: + typedef std::array BackupID; + + static std::pair BackupIDFromString(const std::string& s); + struct Row { + Row(); + e2e::DB2RowState::State state; + uint8_t tries; + uint8_t data_size; // should be MIN_DATA_SIZE <= data_size <= MAX_DATA_SIZE, or 0 if unset + // We use std::array here to avoid lots of extra heap allocations. + // We store slightly more data than necessary if client data is + // smaller than MAX_DATA_SIZE, but we make up for it in at least + // three 64-bit pointers if these were std::string. + std::array data; + std::array pin; + + void Clear(e2e::DB2RowState::State s); + }; + + // Execute each of the four request types.
+ void Backup(const BackupID& id, const client::BackupRequest& request, client::BackupResponse* resp); + void Restore(const BackupID& id, const client::RestoreRequest& request, client::RestoreResponse* resp); + void Delete(const BackupID& id, const client::DeleteRequest& request, client::DeleteResponse* resp); + void Expose(const BackupID& id, const client::ExposeRequest& request, client::ExposeResponse* resp); + // We use std::map over std::unordered_map because order matters to us. + // We need a consistently ordered keyspace for data transfers between + // replicas. + std::map rows_; +}; + +} // namespace svr2::db + +#endif // __SVR2_DB_DB2_H__ diff --git a/enclave/db/db3.cc b/enclave/db/db3.cc new file mode 100644 index 0000000..e843ef5 --- /dev/null +++ b/enclave/db/db3.cc @@ -0,0 +1,298 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "db/db3.h" + +#include +#include +#include + +#include "util/log.h" +#include "util/bytes.h" +#include "util/hex.h" +#include "util/constant.h" +#include "util/endian.h" +#include "context/context.h" +#include "metrics/metrics.h" +#include "proto/clientlog.pb.h" + +namespace svr2::db { + +DB::Request* DB3::Protocol::RequestPB(context::Context* ctx) const { + return ctx->Protobuf(); +} + +DB::Log* DB3::Protocol::LogPB(context::Context* ctx) const { + return ctx->Protobuf(); +} + +std::pair DB3::Protocol::LogPBFromRequest( + context::Context* ctx, + Request&& request, + const std::string& authenticated_id) const { + auto r = dynamic_cast(&request); + if (r == nullptr) { + return std::make_pair(nullptr, COUNTED_ERROR(DB3_RequestInvalid)); + } + if (authenticated_id.size() != BACKUP_ID_SIZE) { + return std::make_pair(nullptr, COUNTED_ERROR(DB3_BackupIDSize)); + } + auto log = ctx->Protobuf(); + log->set_backup_id(authenticated_id); + *log->mutable_req() = std::move(*r); + if (log->req().inner_case() == client::Request3::kCreate) { + auto [priv, pub] = NewKeys(); + 
log->set_create_privkey(util::ByteArrayToString(priv)); + log->set_create_pubkey(util::ByteArrayToString(pub)); + } + return std::make_pair(log, error::OK); +} + +const std::string& DB3::Protocol::LogKey(const DB::Log& req) const { + auto r = dynamic_cast(&req); + CHECK(r != nullptr); + return r->backup_id(); +} + +error::Error DB3::Protocol::ValidateClientLog(const DB::Log& log_pb) const { + auto log = dynamic_cast(&log_pb); + if (log == nullptr) { return COUNTED_ERROR(DB3_RequestInvalid); } + if (log->backup_id().size() != BACKUP_ID_SIZE) { return COUNTED_ERROR(DB3_BackupIDSize); } + switch (log->req().inner_case()) { + case client::Request3::kCreate: { + auto r = log->req().create(); + if (r.max_tries() < 1 || r.max_tries() > 255) { return COUNTED_ERROR(DB3_MaxTriesOutOfRange); } + if (r.blinded_element().size() != ELEMENT_SIZE) { return COUNTED_ERROR(DB3_BlindedElementSize); } + if (log->create_privkey().size() != sizeof(PrivateKey)) { return COUNTED_ERROR(DB3_LogPrivateKeyInvalid); } + if (log->create_pubkey().size() != sizeof(PublicKey)) { return COUNTED_ERROR(DB3_LogPublicKeyInvalid); } + } break; + case client::Request3::kEvaluate: { + auto r = log->req().evaluate(); + if (r.blinded_element().size() != ELEMENT_SIZE) { return COUNTED_ERROR(DB3_BlindedElementSize); } + } break; + case client::Request3::kRemove: { + // nothing to do + } break; + default: + return COUNTED_ERROR(DB3_ToplevelRequestType); + } + return error::OK; +} + +const DB::Protocol* DB3::P() const { + static DB3::Protocol rr; + return &rr; +} + +size_t DB3::Protocol::MaxRowSerializedSize() const { + const size_t PROTOBUF_SMALL_STRING_EXTRA = 2; // additional bytes for serializing string + const size_t PROTOBUF_SMALL_INT = 2; // bytes for serializing a small integer + return BACKUP_ID_SIZE + PROTOBUF_SMALL_STRING_EXTRA + // backup ID + SCALAR_SIZE + PROTOBUF_SMALL_STRING_EXTRA + // priv key + PROTOBUF_SMALL_INT; // tries +} + +DB::Response* DB3::Run(context::Context* ctx, const DB::Log& 
log_pb) { + MEASURE_CPU(ctx, cpu_db_client_request); + CHECK(P()->ValidateClientLog(log_pb) == error::OK); + auto log = dynamic_cast(&log_pb); + CHECK(log != nullptr); + auto out = ctx->Protobuf(); + auto [id, err] = util::StringToByteArray(log->backup_id()); + CHECK(err == error::OK); + switch (log->req().inner_case()) { + case client::Request3::kCreate: { + Create(ctx, id, log->create_privkey(), log->create_pubkey(), log->req().create(), out->mutable_create()); + } break; + case client::Request3::kEvaluate: { + Evaluate(ctx, id, log->req().evaluate(), out->mutable_evaluate()); + } break; + case client::Request3::kRemove: { + Remove(ctx, id, log->req().remove(), out->mutable_remove()); + } break; + default: CHECK(nullptr == "should never reach here, client log already validated"); + } + return out; +} + +std::pair DB3::RowsAsProtos(context::Context* ctx, const std::string& exclusive_start, size_t size, google::protobuf::RepeatedPtrField* out) const { + MEASURE_CPU(ctx, cpu_db_repl_send); + auto iter = rows_.begin(); + if (!exclusive_start.empty()) { + auto [id, err] = util::StringToByteArray(exclusive_start); + if (err != error::OK) { + return std::make_pair("", err); + } + iter = rows_.upper_bound(id); + } + auto row = ctx->Protobuf(); + for (size_t i = 0; i < size && iter != rows_.end(); i++, ++iter) { + row->Clear(); + row->set_backup_id(util::ByteArrayToString(iter->first)); + row->set_priv(util::ByteArrayToString(iter->second.priv)); + row->set_tries(iter->second.tries); + if (!row->SerializeToString(out->Add())) { + return std::make_pair("", COUNTED_ERROR(DB3_ReplicationInvalidRow)); + } + } + LOG(DEBUG) << "DB sending rows in (" << util::PrefixToHex(exclusive_start, 8) << ", " << util::PrefixToHex(row->backup_id(), 8) << "]"; + return std::make_pair(row->backup_id(), error::OK); +} + +std::pair DB3::LoadRowsFromProtos(context::Context* ctx, const google::protobuf::RepeatedPtrField& rows) { + MEASURE_CPU(ctx, cpu_db_repl_recv); + size_t initial_rows = 
rows_.size(); + auto row = ctx->Protobuf(); + for (int i = 0; i < rows.size(); i++) { + row->Clear(); + if (!row->ParseFromString(rows.Get(i))) { + return std::make_pair("", COUNTED_ERROR(DB3_ReplicationInvalidRow)); + } + if (row->tries() > MAX_ALLOWED_MAX_TRIES || + row->tries() < MIN_ALLOWED_MAX_TRIES) { + return std::make_pair("", COUNTED_ERROR(DB3_ReplicationInvalidRow)); + } + auto [key, err1] = util::StringToByteArray(row->backup_id()); + if (err1 != error::OK) { + return std::make_pair("", err1); + } + if (rows_.size() && key <= rows_.rbegin()->first) { + return std::make_pair("", COUNTED_ERROR(DB3_ReplicationOutOfOrder)); + } + auto [priv, err2] = util::StringToByteArray(row->priv()); + if (err2 != error::OK) { + return std::make_pair("", err2); + } + + Row r; + r.tries = row->tries(); + r.priv = priv; + rows_.emplace_hint(rows_.end(), key, std::move(r)); + GAUGE(db, rows)->Set(rows_.size()); + } + if (rows_.size() != initial_rows + rows.size()) { + // This ensures that we didn't accidentally attempt to load rows that + // already exist within the DB. 
+ return std::make_pair("", COUNTED_ERROR(DB3_LoadedRowsAlreadyInDB)); + } + return std::make_pair(row->backup_id(), error::OK); +} + +std::array DB3::Hash(context::Context* ctx) const { + MEASURE_CPU(ctx, cpu_db_hash); + crypto_hash_sha256_state sha; + crypto_hash_sha256_init(&sha); + uint8_t num[8]; + util::BigEndian64Bytes(rows_.size(), num); + crypto_hash_sha256_update(&sha, num, sizeof(num)); + for (auto iter = rows_.cbegin(); iter != rows_.cend(); ++iter) { + crypto_hash_sha256_update(&sha, iter->first.data(), iter->first.size()); + util::BigEndian64Bytes(iter->second.tries, num); + crypto_hash_sha256_update(&sha, num, sizeof(num)); + crypto_hash_sha256_update(&sha, iter->second.priv.data(), iter->second.priv.size()); + } + std::array out; + crypto_hash_sha256_final(&sha, out.data()); + return out; +} + +std::pair DB3::BlindEvaluate(const DB3::PrivateKey& key, const DB3::Element& blinded_element) { + Element out{0}; + int ret = 0; + if (0 != (ret = crypto_scalarmult_ristretto255(out.data(), key.data(), blinded_element.data()))) { + LOG(WARNING) << "crypto_scalarmult_ristretto255 error: " << ret; + return std::make_pair(out, COUNTED_ERROR(DB3_ScalarMultFailure)); + } + return std::make_pair(out, error::OK); +} + +std::pair DB3::Protocol::NewKeys() { + PrivateKey priv{0}; + PublicKey pub{0}; + crypto_core_ristretto255_scalar_random(priv.data()); + // This will only return non-zero if `priv == 0`, which should never happen. + // TODO: Consider using either a protocol specific or a server specific base point. 
+ CHECK(0 == crypto_scalarmult_ristretto255_base(pub.data(), priv.data())); + return std::make_pair(priv, pub); +} + +void DB3::Create( + context::Context* ctx, + const DB3::BackupID& id, + const std::string& privkey, + const std::string& pubkey, + const client::CreateRequest& req, + client::CreateResponse* resp) { + auto [elt, err1] = util::StringToByteArray(req.blinded_element()); + if (err1 != error::OK) { + resp->set_status(client::CreateResponse::INVALID_REQUEST); + return; + } + auto [priv, err2] = util::StringToByteArray(privkey); + if (err2 != error::OK) { + resp->set_status(client::CreateResponse::ERROR); + return; + } + auto [pub, err3] = util::StringToByteArray(pubkey); + if (err3 != error::OK) { + resp->set_status(client::CreateResponse::ERROR); + return; + } + auto [evaluated, err4] = BlindEvaluate(priv, elt); + if (err4 != error::OK) { + resp->set_status(client::CreateResponse::ERROR); + return; + } + rows_[id] = { + .priv = priv, + .tries = (uint8_t) req.max_tries(), + }; + GAUGE(db, rows)->Set(rows_.size()); + resp->set_evaluated_element(util::ByteArrayToString(evaluated)); + resp->set_public_key(util::ByteArrayToString(pub)); + resp->set_status(client::CreateResponse::OK); +} + +void DB3::Evaluate( + context::Context* ctx, + const DB3::BackupID& id, + const client::EvaluateRequest& req, + client::EvaluateResponse* resp) { + auto [elt, err1] = util::StringToByteArray(req.blinded_element()); + if (err1 != error::OK) { + resp->set_status(client::EvaluateResponse::INVALID_REQUEST); + return; + } + auto find = rows_.find(id); + if (find == rows_.end()) { + resp->set_status(client::EvaluateResponse::MISSING); + return; + } + auto [evaluated, err2] = BlindEvaluate(find->second.priv, elt); + if (err2 != error::OK) { + resp->set_status(client::EvaluateResponse::ERROR); + return; + } + find->second.tries--; + resp->set_tries_remaining(find->second.tries); + if (find->second.tries == 0) { + rows_.erase(find); + GAUGE(db, rows)->Set(rows_.size()); + } + 
resp->set_evaluated_element(util::ByteArrayToString(evaluated)); + resp->set_status(client::EvaluateResponse::OK); +} + +void DB3::Remove( + context::Context* ctx, + const DB3::BackupID& id, + const client::RemoveRequest& req, + client::RemoveResponse* resp) { + auto find = rows_.find(id); + if (find != rows_.end()) { + rows_.erase(find); + GAUGE(db, rows)->Set(rows_.size()); + } +} + +} // namespace svr2::db diff --git a/enclave/db/db3.h b/enclave/db/db3.h new file mode 100644 index 0000000..ea9fb34 --- /dev/null +++ b/enclave/db/db3.h @@ -0,0 +1,124 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_DB_DB3_H__ +#define __SVR2_DB_DB3_H__ + +#include +#include +#include "proto/error.pb.h" +#include "proto/e2e.pb.h" +#include "proto/msgs.pb.h" +#include "sip/hasher.h" +#include "context/context.h" +#include "util/log.h" +#include "db/db.h" +#include "proto/client3.pb.h" +#include +#include + +namespace svr2::db { + +class DB3 : public DB { + public: + DELETE_COPY_AND_ASSIGN(DB3); + DB3() {} + virtual ~DB3() {} + + static const size_t BACKUP_ID_SIZE = 16; + static const size_t SCALAR_SIZE = crypto_scalarmult_ristretto255_SCALARBYTES; + static const size_t ELEMENT_SIZE = crypto_scalarmult_ristretto255_BYTES; + typedef std::array BackupID; + typedef std::array Scalar; + typedef std::array Element; + typedef Scalar PrivateKey; + typedef Element PublicKey; + static const uint16_t MAX_ALLOWED_MAX_TRIES = 255; + static const uint16_t MIN_ALLOWED_MAX_TRIES = 1; + + // Protocol encapsulates typing requests and responses for clients. 
+ class Protocol : public DB::Protocol { + public: + virtual DB::Request* RequestPB(context::Context* ctx) const; + virtual DB::Log* LogPB(context::Context* ctx) const; + virtual std::pair LogPBFromRequest( + context::Context* ctx, + Request&& request, + const std::string& authenticated_id) const; + virtual const std::string& LogKey(const DB::Log& r) const; + virtual error::Error ValidateClientLog(const DB::Log& log) const; + virtual size_t MaxRowSerializedSize() const; + public_for_test: + static std::pair NewKeys(); + }; + // P() returns a pointer to a _static_ Protocol object, + // which will outlast the DB object. + virtual const DB::Protocol* P() const; + + // Run a client log request and yield a response. + // The client log should already have been checked with ValidateClientLog; + // failing to do so will CHECK-fail. + // It's assumed that validation happens on Raft log insert, so that + // outputs from the Raft log are already validated. + // + // Output response is valid within the passed-in context. + virtual DB::Response* Run(context::Context* ctx, const DB::Log& request); + + // Get rows from this database in range (exclusive_start, ...], returning + // no more than [size] rows. If it returns <[size] rows, the end of the database + // has been reached. Pass in an empty string to start with the first key in + // the database. Returns the key of the largest returned row. + virtual std::pair RowsAsProtos( + context::Context* ctx, + const std::string& exclusive_start, + size_t size, + google::protobuf::RepeatedPtrField* out) const; + // Update this database using the given database row states. + // This will return an error if any of the DatabaseRowStates contain + // rows that already exist within the database. Rows must be lexicographically + // larger than any existing row in the database. Returns the row key + // of the last row inserted into the database, on success.
+ virtual std::pair LoadRowsFromProtos( + context::Context* ctx, + const google::protobuf::RepeatedPtrField& rows); + + // Compute a hash of the entire database. This is not designed to + // be useful for security-focussed integrity checking, but should be + // sufficient to verify that replicated data matches up between source + // and destination. + virtual std::array Hash(context::Context* ctx) const; + + // Get the number of backups stored in the database + virtual size_t row_count() const { return rows_.size(); } + + private: + static std::pair BlindEvaluate(const PrivateKey& key, const Element& blinded_element); + + struct Row { + PrivateKey priv; + uint8_t tries; + }; + std::map rows_; + + void Create( + context::Context* ctx, + const BackupID& id, + const std::string& privkey, + const std::string& pubkey, + const client::CreateRequest& req, + client::CreateResponse* resp); + void Evaluate( + context::Context* ctx, + const BackupID& id, + const client::EvaluateRequest& req, + client::EvaluateResponse* resp); + void Remove( + context::Context* ctx, + const BackupID& id, + const client::RemoveRequest& req, + client::RemoveResponse* resp); +}; + +} // namespace svr2::db + +#endif // __SVR2_DB_DB3_H__ diff --git a/enclave/db/tests/db2.cc b/enclave/db/tests/db2.cc new file mode 100644 index 0000000..43dfaf4 --- /dev/null +++ b/enclave/db/tests/db2.cc @@ -0,0 +1,247 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP context +//TESTDEP env +//TESTDEP env/test +//TESTDEP env +//TESTDEP util +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include "db/db2.h" +#include "env/env.h" +#include "util/log.h" +#include "util/endian.h" +#include "proto/client.pb.h" +#include "proto/clientlog.pb.h" +#include + +namespace svr2::db { + +class DB2Test : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } + + context::Context ctx; + 
DB2 db; +}; + +TEST_F(DB2Test, SingleBackupLifecycle) { + { + client::Log2 log; + auto b = log.mutable_req()->mutable_backup(); + log.set_backup_id("BACKUP7890123456"); + b->set_data("DATA56789012345678901234567890123456789012345678"); + b->set_pin("PIN45678901234567890123456789012"); + b->set_max_tries(2); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::BackupResponse::OK, resp->backup().status()); + } + { + client::Log2 log; + auto b = log.mutable_req()->mutable_expose(); + log.set_backup_id("BACKUP7890123456"); + b->set_data("DATA56789012345678901234567890123456789012345678"); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::ExposeResponse::OK, resp->expose().status()); + } + { + client::Log2 log; + auto r = log.mutable_req()->mutable_restore(); + log.set_backup_id("BACKUP7890123456"); + r->set_pin("PIN45678901234567890123456789012"); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::RestoreResponse::OK, resp->restore().status()); + ASSERT_EQ("DATA56789012345678901234567890123456789012345678", resp->restore().data()); + ASSERT_EQ(2, resp->restore().tries()); + } + { + client::Log2 log; + auto r = log.mutable_req()->mutable_restore(); + log.set_backup_id("BACKUP7890123456"); + r->set_pin("PIN............................2"); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::RestoreResponse::PIN_MISMATCH, resp->restore().status()); + ASSERT_EQ("", resp->restore().data()); + ASSERT_EQ(1, resp->restore().tries()); + } + { + client::Log2 log; + auto r = log.mutable_req()->mutable_restore(); + log.set_backup_id("BACKUP7890123456"); + r->set_pin("PIN............................2"); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::RestoreResponse::MISSING, resp->restore().status()); + } +} + +TEST_F(DB2Test, SmallerData) { + { + client::Log2 log; + auto b = log.mutable_req()->mutable_backup(); + log.set_backup_id("BACKUP7890123456"); + 
b->set_data("DATA5678901234567890123456789012"); // 32 bytes + b->set_pin("PIN45678901234567890123456789012"); + b->set_max_tries(2); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::BackupResponse::OK, resp->backup().status()); + } + { + client::Log2 log; + auto b = log.mutable_req()->mutable_expose(); + log.set_backup_id("BACKUP7890123456"); + b->set_data("DATA5678901234567890123456789012"); // 32 bytes + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::ExposeResponse::OK, resp->expose().status()); + } + { + client::Log2 log; + auto r = log.mutable_req()->mutable_restore(); + log.set_backup_id("BACKUP7890123456"); + r->set_pin("PIN45678901234567890123456789012"); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::RestoreResponse::OK, resp->restore().status()); + ASSERT_EQ("DATA5678901234567890123456789012", resp->restore().data()); + } +} + +TEST_F(DB2Test, Delete) { + { + client::Log2 log; + auto b = log.mutable_req()->mutable_backup(); + log.set_backup_id("BACKUP7890123456"); + b->set_data("DATA5678901234567890123456789012"); // 32 bytes + b->set_pin("PIN45678901234567890123456789012"); + b->set_max_tries(2); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::BackupResponse::OK, resp->backup().status()); + } + { + client::Log2 log; + auto b = log.mutable_req()->mutable_expose(); + log.set_backup_id("BACKUP7890123456"); + b->set_data("DATA5678901234567890123456789012"); // 32 bytes + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::ExposeResponse::OK, resp->expose().status()); + } + { + client::Log2 log; + auto d = log.mutable_req()->mutable_delete_(); + log.set_backup_id("BACKUP7890123456"); + auto resp = dynamic_cast(db.Run(&ctx, log)); + } + { + client::Log2 log; + auto r = log.mutable_req()->mutable_restore(); + log.set_backup_id("BACKUP7890123456"); + r->set_pin("PIN45678901234567890123456789012"); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + 
ASSERT_EQ(client::RestoreResponse::MISSING, resp->restore().status()); + } +} + +TEST_F(DB2Test, MultipleRows) { + std::string backup_id("BACKUP789012345."); + std::string data("DATA567890123456789012345678901."); + for (int i = 0; i < 256; i++) { + backup_id[DB2::BACKUP_ID_SIZE-1] = i; + data[31] = i; + { + client::Log2 log; + auto b = log.mutable_req()->mutable_backup(); + log.set_backup_id(backup_id); + b->set_data(data); // 32 bytes + b->set_pin("PIN45678901234567890123456789012"); + b->set_max_tries(2); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::BackupResponse::OK, resp->backup().status()); + } + { + client::Log2 log; + auto b = log.mutable_req()->mutable_expose(); + log.set_backup_id(backup_id); + b->set_data(data); // 32 bytes + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::ExposeResponse::OK, resp->expose().status()); + } + } + for (int i = 0; i < 256; i++) { + client::Log2 log; + auto r = log.mutable_req()->mutable_restore(); + backup_id[DB2::BACKUP_ID_SIZE-1] = i; + data[31] = i; + log.set_backup_id(backup_id); + r->set_pin("PIN45678901234567890123456789012"); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::RestoreResponse::OK, resp->restore().status()); + ASSERT_EQ(data, resp->restore().data()); + } +} + +TEST_F(DB2Test, HashMatch) { + std::string backup_id("BACKUP789012345."); + std::string data("DATA567890123456789012345678901."); + uint64_t hash = 0; + for (int i = 0; i < 256; i++) { + client::Log2 log; + auto b = log.mutable_req()->mutable_backup(); + backup_id[DB2::BACKUP_ID_SIZE-1] = i; + data[31] = i; + log.set_backup_id(backup_id); + b->set_data(data); // 32 bytes + b->set_pin("PIN45678901234567890123456789012"); + b->set_max_tries(2); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::BackupResponse::OK, resp->backup().status()); + uint64_t new_hash = util::BigEndian64FromBytes(db.Hash(&ctx).data()); + ASSERT_NE(hash, new_hash); // hash changes with every 
database change. + hash = new_hash; + } + ASSERT_EQ(hash, 784802678439774802ULL); +} + +TEST_F(DB2Test, HashMatchBackwards) { + // Make sure that even if we construct the same DB in a different way + // (in this case, by inserting back IDs in reverse of HashMatch), we + // get the same result. + std::string backup_id("BACKUP789012345."); + std::string data("DATA567890123456789012345678901."); + for (int i = 255; i >= 0; i--) { + client::Log2 log; + auto b = log.mutable_req()->mutable_backup(); + backup_id[DB2::BACKUP_ID_SIZE-1] = i; + data[31] = i; + log.set_backup_id(backup_id); + b->set_data(data); // 32 bytes + b->set_pin("PIN45678901234567890123456789012"); + b->set_max_tries(2); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + ASSERT_EQ(client::BackupResponse::OK, resp->backup().status()); + } + ASSERT_EQ(util::BigEndian64FromBytes(db.Hash(&ctx).data()), 784802678439774802ULL); +} + +} // namespace svr2::db diff --git a/enclave/db/tests/db3.cc b/enclave/db/tests/db3.cc new file mode 100644 index 0000000..5a7ba84 --- /dev/null +++ b/enclave/db/tests/db3.cc @@ -0,0 +1,373 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP context +//TESTDEP env +//TESTDEP env/test +//TESTDEP env +//TESTDEP util +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include "db/db3.h" +#include "env/env.h" +#include "util/log.h" +#include "util/endian.h" +#include "util/bytes.h" +#include "util/hex.h" +#include "proto/client3.pb.h" +#include "proto/clientlog.pb.h" +#include +#include +#include + +namespace svr2::db { + +class DB3Test : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } + + context::Context ctx; + DB3 db; + static std::string backup_id; +}; + +std::string DB3Test::backup_id("BACKUP7890123456"); + +TEST_F(DB3Test, SingleBackupLifecycle) { + std::string blinded_element; + blinded_element.resize(DB3::ELEMENT_SIZE); + 
crypto_core_ristretto255_random( + reinterpret_cast(blinded_element.data())); + std::string evaluated_element; + int tries = 3; + { + client::Log3 log; + log.set_backup_id(backup_id); + auto b = log.mutable_req()->mutable_create(); + b->set_max_tries(3); + b->set_blinded_element(blinded_element); + auto [priv, pub] = DB3::Protocol::NewKeys(); + log.set_create_privkey(util::ByteArrayToString(priv)); + log.set_create_pubkey(util::ByteArrayToString(pub)); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + auto r = resp->create(); + ASSERT_EQ(client::CreateResponse::OK, r.status()); + evaluated_element = r.evaluated_element(); + } + for (int i = 0; i < tries; i++) { + client::Log3 log; + log.set_backup_id(backup_id); + auto b = log.mutable_req()->mutable_evaluate(); + b->set_blinded_element(blinded_element); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + auto r = resp->evaluate(); + ASSERT_EQ(client::EvaluateResponse::OK, r.status()); + EXPECT_EQ(r.tries_remaining(), tries - i - 1); + EXPECT_EQ(r.evaluated_element(), evaluated_element); + } + { + client::Log3 log; + log.set_backup_id(backup_id); + auto b = log.mutable_req()->mutable_evaluate(); + b->set_blinded_element(blinded_element); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + auto r = resp->evaluate(); + ASSERT_EQ(client::EvaluateResponse::MISSING, r.status()); + } +} + +TEST_F(DB3Test, Remove) { + std::string blinded_element; + blinded_element.resize(DB3::ELEMENT_SIZE); + crypto_core_ristretto255_random( + reinterpret_cast(blinded_element.data())); + std::string evaluated_element; + int tries = 3; + { + client::Log3 log; + log.set_backup_id(backup_id); + auto b = log.mutable_req()->mutable_create(); + b->set_max_tries(3); + b->set_blinded_element(blinded_element); + auto [priv, pub] = DB3::Protocol::NewKeys(); + log.set_create_privkey(util::ByteArrayToString(priv)); + log.set_create_pubkey(util::ByteArrayToString(pub)); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + auto r = resp->create(); + 
ASSERT_EQ(client::CreateResponse::OK, r.status());
+    evaluated_element = r.evaluated_element();
+  }
+  {
+    // Remove the backup that was just created.  The response is not inspected.
+    // NOTE(review): `b` is never used after this line.
+    client::Log3 log;
+    log.set_backup_id(backup_id);
+    auto b = log.mutable_req()->mutable_remove();
+
+    db.Run(&ctx, log);
+  }
+  {
+    // Evaluating the removed backup ID must now report MISSING.
+    client::Log3 log;
+    log.set_backup_id(backup_id);
+    auto b = log.mutable_req()->mutable_evaluate();
+    b->set_blinded_element(blinded_element);
+
+    auto resp = dynamic_cast(db.Run(&ctx, log));
+    auto r = resp->evaluate();
+    ASSERT_EQ(client::EvaluateResponse::MISSING, r.status());
+  }
+}
+
+// IETF VOPRF v21 test vectors (https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-21.html)
+
+// Pieces of the context string assembled by context_string() below.
+const std::string context_string_prefix{"OPRFV1-"};
+const std::string ciphersuite_identifier{"ristretto255-SHA512"};
+
+// Sizes in bytes used by the helpers and tests below.
+static const size_t PRIVATE_KEY_SIZE = 32;
+static const size_t PUBLIC_KEY_SIZE = 32;
+static const size_t SHA512_BLOCK_BYTES = 128;
+static const size_t SHA512_OUTPUT_BYTES = 64;
+
+// https://www.rfc-editor.org/rfc/rfc8017
+// Integer-to-octet-string: returns [x] encoded as [n] bytes, most significant
+// byte first.  Bits of [x] that do not fit in [n] bytes are silently dropped
+// (the loop simply stops dividing); no "integer too large" error is raised.
+std::string I2OSP(uint64_t x, size_t n) {
+  std::string X;
+  X.resize(n);
+  for(size_t i = 0; i < n; ++i) {
+    X[n-1-i] = x%256;
+    x /= 256;
+  }
+  return X;
+}
+
+/*
+def CreateContextString(mode, identifier):
+  return "OPRFV1-" || I2OSP(mode, 1) || "-" || identifier
+*/
+// Builds the context string above with the mode byte fixed to 0x00 and the
+// ristretto255-SHA512 ciphersuite identifier.
+std::string context_string() {
+  auto mode = I2OSP(0x00, 1);
+  return context_string_prefix + mode + "-" + ciphersuite_identifier;
+}
+
+// Returns the SHA-512 digest of [s] as a string, computed with libsodium's
+// init/update/final API.
+std::string sha512_hash(std::string s) {
+  crypto_hash_sha512_state sha;
+  crypto_hash_sha512_init(&sha);
+  crypto_hash_sha512_update(&sha, reinterpret_cast(s.data()), s.size());
+  std::array out;
+  crypto_hash_sha512_final(&sha, out.data());
+  return util::ByteArrayToString(out);
+}
+
+// Byte-wise XOR of two strings.  CHECK-fails unless both have the same size.
+std::string strxor(const std::string& lhs, const std::string& rhs) {
+  CHECK(lhs.size() == rhs.size());
+  std::string result;
+  result.resize(rhs.size());
+  for(size_t i = 0; i < lhs.size(); ++i) {
+    result[i] = lhs[i] ^ rhs[i];
+  }
+  return result;
+}
+
+// Returns true iff every element of [arr] is zero.  The loop has no early
+// exit: all N elements are visited regardless of their values.
+template
+bool is_zero(const std::array& arr) {
+  bool result = true;
+  for(size_t i = 0; i < N; ++i) {
+    result = result && (arr[i] == 0);
+  }
+  return result;
+}
+
+// https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-12#name-expand_message_xmd
+// expand_message_xmd instantiated with SHA-512: expands [msg] under the domain
+// separation tag [dst] into N output bytes (N is the template argument).
+// NOTE(review): the spec also aborts when len_in_bytes > 65535 or
+// len(DST) > 255.  Only `ell <= 255` is checked here, and I2OSP truncates
+// silently, so those two limits are not enforced.
+template
+std::array ExpandMessageXMD_SHA512(std::string msg, std::string dst) {
+  // ell = ceil(N / SHA512_OUTPUT_BYTES): number of hash blocks needed.
+  auto ell = N / SHA512_OUTPUT_BYTES + ((N%SHA512_OUTPUT_BYTES == 0) ? 0 : 1);
+  CHECK(ell <= 255);
+  LOG(DEBUG) << "expand_message_xmd blocks: " << ell;
+  std::array result{0};
+
+  // DST_prime = DST || I2OSP(len(DST), 1)
+  auto dst_prime = dst + I2OSP(dst.size(),1);
+  // Z_pad is one full hash input block of zero bytes.
+  auto z_pad = I2OSP(0, SHA512_BLOCK_BYTES);
+  auto l_i_b_str = I2OSP(N,2);
+  auto msg_prime = z_pad + msg + l_i_b_str + I2OSP(0,1) + dst_prime;
+  auto b_0 = sha512_hash(msg_prime);
+  auto b_1 = sha512_hash(b_0 + I2OSP(1,1) + dst_prime);
+  // The first block may already exceed N; copy at most N bytes of it.
+  auto bytes_to_copy = std::min(b_1.size(), N);
+  std::copy(b_1.data(), b_1.data()+ bytes_to_copy, result.data());
+  auto b_last = b_1;
+  // b_i = H(strxor(b_0, b_(i-1)) || I2OSP(i, 1) || DST_prime) for i = 2..ell.
+  for(size_t i = 2; i <= ell; ++i) {
+    auto b_next = sha512_hash(
+      strxor(b_0, b_last)
+      + I2OSP(i,1)
+      + dst_prime
+    );
+    // Block i lands at offset (i-1)*64; the final block may be partial.
+    auto bytes_to_copy = std::min(SHA512_OUTPUT_BYTES, N - (i-1)*SHA512_OUTPUT_BYTES);
+    LOG(DEBUG) << "copying " << bytes_to_copy << " bytes";
+    std::copy(b_next.data(), b_next.data() + bytes_to_copy, result.data() + (i-1)*SHA512_OUTPUT_BYTES);
+    b_last = b_next;
+  }
+  return result;
+}
+
+// Hashes [data] to a ristretto255 scalar: expands it to 64 uniform bytes under
+// the DST "HashToScalar-" || context_string(), then reduces with libsodium.
+std::array HashToScalar(const std::string& data) {
+  std::string dst = std::string{"HashToScalar-"} + context_string();
+  auto uniform_bytes = ExpandMessageXMD_SHA512<64>(data, dst);
+  std::array s;
+  // TODO: verify that this interprets numbers in little-endian order
+  crypto_core_ristretto255_scalar_reduce(s.data(), uniform_bytes.data());
+  return s;
+}
+
+// Deterministically derives a key pair from [seed] and [info].  A counter byte
+// is appended to the input and incremented until the reduced scalar is
+// non-zero.
+// Returns (public key, private key) -- in that order.
+// Note the DST is "DeriveKeyPair" with no trailing dash, unlike the other
+// helpers here.
+std::pair, std::array>
+DeriveKeyPair(std::string seed, std::string info) {
+  std::string derive_input = seed + I2OSP(info.size(),2) + info;
+  size_t counter = 0;
+  std::array sk{0};
+  std::array pk{0};
+
+  std::string dst = std::string{"DeriveKeyPair"} + context_string();
+  while(is_zero(sk)) {
+    LOG(DEBUG) << "derive key pair attempt " << counter;
+    CHECK(counter < 255);
+    auto uniform_bytes =
+      ExpandMessageXMD_SHA512<64>(derive_input + I2OSP(counter,1), dst);
+    crypto_core_ristretto255_scalar_reduce(sk.data(), uniform_bytes.data());
+    counter += 1;
+  }
+  // Public key = sk * base point.
+  CHECK(0 == crypto_scalarmult_ristretto255_base(pk.data(), sk.data()));
+
+  return std::make_pair(pk, sk);
+}
+
+// Hashes [input] to a ristretto255 group element: expands it to 64 uniform
+// bytes under the DST "HashToGroup-" || context_string(), then maps them to a
+// point with crypto_core_ristretto255_from_hash.
+std::array HashToGroup(std::string input) {
+  std::string dst = std::string{"HashToGroup-"} + context_string();
+  auto uniform_bytes = ExpandMessageXMD_SHA512<64>(input, dst);
+  std::array result{};
+  crypto_core_ristretto255_from_hash(result.data(), uniform_bytes.data());
+  return result;
+}
+
+// DeriveKeyPair on a fixed seed and key info must produce the expected
+// private key (hex).
+TEST_F(DB3Test, IETF_A_1_1) {
+  auto seed = util::HexToBytes("a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3");
+  auto key_info = util::HexToBytes("74657374206b6579");
+  auto sk_expected = "5ebcea5ee37023ccb9fc2d2019f9d7737be85591ae8652ffa9ef0f4d37063b0e";
+
+  // Debug aid: dump the context string one character at a time.
+  auto cs = context_string();
+  for(size_t i = 0; i < cs.size(); ++i) {
+    LOG(DEBUG) << " (" << static_cast(cs[i]) << ") " << cs[i] ;
+  }
+  LOG(DEBUG) << cs;
+
+  auto [pk, sk] = DeriveKeyPair(seed, key_info);
+  auto sk_hex = util::BytesToHex(sk.data(), PRIVATE_KEY_SIZE);
+  EXPECT_EQ(sk_hex, sk_expected);
+}
+
+// expand_message_xmd check: msg "abc", 0x80 output bytes (two SHA-512 blocks).
+// NOTE(review): `len_in_bytes` is unused, and the "here"/"there" logs look
+// like leftover debugging.
+TEST_F(DB3Test, EXPAND_MESSAGE_XMD_1) {
+  std::string dst{"QUUX-V01-CS02-with-expander-SHA512-256"};
+  std::string msg{"abc"};
+  size_t len_in_bytes = 0x80;
+  auto uniform_bytes = ExpandMessageXMD_SHA512<0x80>(msg, dst);
+  LOG(DEBUG) << "here";
+  auto hex = util::BytesToHex(uniform_bytes.data(), uniform_bytes.size());
+  LOG(DEBUG) << hex;
+  LOG(DEBUG) << "there";
+
+  EXPECT_EQ(util::BytesToHex(uniform_bytes.data(), uniform_bytes.size()), "7f1dddd13c08b543f2e2037b14cefb255b44c83cc397c1786d975653e36a6b11bdd7732d8b38adb4a0edc26a0cef4bb45217135456e58fbca1703cd6032cb1347ee720b87972d63fbf232587043ed2901bce7f22610c0419751c065922b488431851041310ad659e4b23520e1772ab29dcdeb2002222a363f0c2b1c972b3efe1");
+}
+
+TEST_F(DB3Test, EXPAND_MESSAGE_XMD_2) {
+  std::string
dst{"QUUX-V01-CS02-with-expander-SHA512-256"}; + std::string msg{"abcdef0123456789"}; + size_t len_in_bytes = 0x20; + auto uniform_bytes = ExpandMessageXMD_SHA512<0x20>(msg, dst); + LOG(DEBUG) << util::BytesToHex(uniform_bytes.data(), uniform_bytes.size()); + + EXPECT_EQ(util::BytesToHex(uniform_bytes.data(), uniform_bytes.size()), "087e45a86e2939ee8b91100af1583c4938e0f5fc6c9db4b107b83346bc967f58"); +} + +TEST_F(DB3Test, IETF_A_1_1_1) { + auto sk = util::HexToBytes("5ebcea5ee37023ccb9fc2d2019f9d7737be85591ae8652ffa9ef0f4d37063b0e"); + auto input = util::HexToBytes("00"); + auto blind = util::HexToBytes("64d37aed22a27f5191de1c1d69fadb899d8862b58eb4220029e036ec4c1f6706"); + auto blinded_element_expected = util::HexToBytes("609a0ae68c15a3cf6903766461307e5c8bb2f95e7e6550e1ffa2dc99e412803c"); + std::string evaluation_element_hex = "7ec6578ae5120958eb2db1745758ff379e77cb64fe77b0b2d8cc917ea0869c7e"; + std::string output_hex = "527759c3d9366f277d8c6020418d96bb393ba2afb20ff90df23fb7708264e2f3ab9135e3bd69955851de4b1f9fe8a0973396719b7912ba9ee8aa7d0b5e24bcf6"; + + // Compute blinded element + std::array blinded_element; + std::array elt = HashToGroup(input); + auto ret = crypto_scalarmult_ristretto255(blinded_element.data(), reinterpret_cast(blind.data()), elt.data()); + EXPECT_EQ(util::BytesToHex(blinded_element.data(), PUBLIC_KEY_SIZE), "609a0ae68c15a3cf6903766461307e5c8bb2f95e7e6550e1ffa2dc99e412803c"); + + // send to server to evaluate + std::string evaluated_element; + int tries = 3; + { + client::Log3 log; + log.set_backup_id(backup_id); + auto b = log.mutable_req()->mutable_create(); + b->set_max_tries(3); + b->set_blinded_element(util::ByteArrayToString(blinded_element)); + std::array pk{}; + CHECK(0 == crypto_scalarmult_ristretto255_base(pk.data(), reinterpret_cast(sk.data()))); + log.set_create_privkey(sk); + log.set_create_pubkey(util::ByteArrayToString(pk)); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + auto r = resp->create(); + 
ASSERT_EQ(client::CreateResponse::OK, r.status()); + evaluated_element = r.evaluated_element(); + auto [ee_data, err] = util::StringToByteArray(evaluated_element); + EXPECT_EQ(util::BytesToHex(ee_data.data(), ee_data.size()), evaluation_element_hex); + } +} + + +TEST_F(DB3Test, IETF_A_1_1_2) { + auto seed = util::HexToBytes("a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3"); + auto key_info = util::HexToBytes("74657374206b6579"); + auto sk = util::HexToBytes("5ebcea5ee37023ccb9fc2d2019f9d7737be85591ae8652ffa9ef0f4d37063b0e"); + auto input = util::HexToBytes("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + auto blind = util::HexToBytes("64d37aed22a27f5191de1c1d69fadb899d8862b58eb4220029e036ec4c1f6706"); + auto blinded_element_expected = util::HexToBytes("da27ef466870f5f15296299850aa088629945a17d1f5b7f5ff043f76b3c06418"); + auto evaluation_element_hex = "b4cbf5a4f1eeda5a63ce7b77c7d23f461db3fcab0dd28e4e17cecb5c90d02c25"; + auto output_hex = "f4a74c9c592497375e796aa837e907b1a045d34306a749db9f34221f7e750cb4f2a6413a6bf6fa5e19ba6348eb673934a722a7ede2e7621306d18951e7cf2c73"; + + // Compute blinded element + std::array blinded_element; + std::array elt = HashToGroup(input); + auto ret = crypto_scalarmult_ristretto255(blinded_element.data(), reinterpret_cast(blind.data()), elt.data()); + EXPECT_EQ(util::BytesToHex(blinded_element.data(), PUBLIC_KEY_SIZE), "da27ef466870f5f15296299850aa088629945a17d1f5b7f5ff043f76b3c06418"); + + // send to server to evaluate + std::string evaluated_element; + int tries = 3; + { + client::Log3 log; + log.set_backup_id(backup_id); + auto b = log.mutable_req()->mutable_create(); + b->set_max_tries(3); + b->set_blinded_element(util::ByteArrayToString(blinded_element)); + std::array pk{}; + CHECK(0 == crypto_scalarmult_ristretto255_base(pk.data(), reinterpret_cast(sk.data()))); + log.set_create_privkey(sk); + log.set_create_pubkey(util::ByteArrayToString(pk)); + + auto resp = dynamic_cast(db.Run(&ctx, log)); + auto r = resp->create(); + 
ASSERT_EQ(client::CreateResponse::OK, r.status()); + evaluated_element = r.evaluated_element(); + auto [ee_data, err] = util::StringToByteArray(evaluated_element); + EXPECT_EQ(util::BytesToHex(ee_data.data(), ee_data.size()), evaluation_element_hex); + } +} + +} // namespace svr2::db diff --git a/enclave/ecalls/ecalls.cc b/enclave/ecalls/ecalls.cc new file mode 100644 index 0000000..e83a607 --- /dev/null +++ b/enclave/ecalls/ecalls.cc @@ -0,0 +1,97 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include +#include +#include +#include +#include +#include "svr2/svr2_t.h" +#include "core/core.h" +#include "proto/error.pb.h" +#include "proto/enclaveconfig.pb.h" +#include "env/env.h" +#include "context/context.h" +#include "util/endian.h" +#include "util/log.h" +#include "metrics/metrics.h" + +namespace svr2::ecalls { +namespace { + +void SeedWeakRandom() { + LOG(INFO) << "Seeding weak randomness with strong"; + // Best-effort seeding of weak randomness from strong. + uint8_t bytes[8]; + env::environment->RandomBytes(bytes, sizeof(bytes)); + srand(util::BigEndian64FromBytes(bytes)); +} +std::unique_ptr global_core; +// Sadly, we don't appear to have access to std::shared_mutex, so we use +// the next best thing. 
+// Lifecycle of `global_core`.  svr2_init moves UNINITIATED -> INITIATING ->
+// INITIATED, and falls back to UNINITIATED when initialization fails.
+enum class GlobalCoreState {
+  UNINITIATED = 0,
+  INITIATING = 1,
+  INITIATED = 2,
+};
+std::atomic global_core_state(GlobalCoreState::UNINITIATED);
+
+}  // namespace
+
+extern "C" {
+
+// Ecall: one-time enclave initialization.
+//   config_size / config: a serialized enclaveconfig::InitConfig.
+//   peer_id: output buffer that receives this core's ID bytes.
+//     NOTE(review): the required buffer size is not visible here -- confirm
+//     against the caller / EDL definition.
+// Returns error::OK on success.  Returns Core_ReInit if initialization has
+// already started or finished, Core_ConfigProtobufParse if the config does not
+// parse, or the error from core::Core::Create.  Both failure paths after the
+// state claim reset the state to UNINITIATED.
+int svr2_init(
+    size_t config_size,
+    unsigned char* config,
+    unsigned char* peer_id) {
+  context::Context ctx;
+  COUNTER(ecalls, init_calls)->Increment();
+  // Atomically claim the right to initialize: only the caller that observes
+  // UNINITIATED proceeds; everyone else gets Core_ReInit.
+  GlobalCoreState state_expected = GlobalCoreState::UNINITIATED;
+  GlobalCoreState state_requested = GlobalCoreState::INITIATING;
+  if (!global_core_state.compare_exchange_strong(state_expected, state_requested)) {
+    return COUNTED_ERROR(Core_ReInit);
+  }
+
+  enclaveconfig::InitConfig config_pb;
+  if (!config_pb.ParseFromArray(config, config_size)) {
+    global_core_state.store(GlobalCoreState::UNINITIATED);
+    return COUNTED_ERROR(Core_ConfigProtobufParse);
+  }
+  // LOG_LEVEL_NONE leaves the current log level untouched.
+  if (config_pb.initial_log_level() != enclaveconfig::LOG_LEVEL_NONE) {
+    util::SetLogLevel(config_pb.initial_log_level());
+  }
+
+  env::Init(config_pb.group_config().simulated());  // Can be called more than once, but never concurrently.
+  SeedWeakRandom();
+
+  LOG(INFO) << "Creating core";
+  auto [core, err] = core::Core::Create(&ctx, config_pb);
+  if (err != error::OK) {
+    global_core_state.store(GlobalCoreState::UNINITIATED);
+    return err;
+  }
+  global_core = std::move(core);
+  // Hand the new core's ID back to the host through the output buffer.
+  const auto peer_id_array = global_core->ID().Get();
+  std::copy(peer_id_array.begin(), peer_id_array.end(), peer_id);
+  // Publish INITIATED only after global_core is fully set up;
+  // svr2_input_message checks this state before touching global_core.
+  global_core_state.store(GlobalCoreState::INITIATED);
+  return error::OK;
+}
+
+// Ecall: deliver one host message to the core.
+//   msg_size / msg: a serialized UntrustedMessage.
+// Returns Core_NoInit if svr2_init has not completed, and
+// Core_ReceiveProtobufParse if the bytes do not parse.  Otherwise returns
+// whatever global_core->Receive returns.
+int svr2_input_message(
+    size_t msg_size,
+    unsigned char* msg) {
+  context::Context ctx;
+  COUNTER(ecalls, host_messages_received)->Increment();
+  COUNTER(ecalls, host_bytes_received)->IncrementBy(msg_size);
+  if (global_core_state.load() != GlobalCoreState::INITIATED) {
+    return COUNTED_ERROR(Core_NoInit);
+  }
+  // The message object is obtained from the per-call context.
+  UntrustedMessage* msg_pb = ctx.Protobuf();
+  if (!msg_pb->ParseFromArray(msg, msg_size)) {
+    return COUNTED_ERROR(Core_ReceiveProtobufParse);
+  }
+  return global_core->Receive(&ctx, *msg_pb);
+}
+
+}  // extern "C"
+}  // namespace svr2::ecalls
diff --git a/enclave/env/env.cc b/enclave/env/env.cc
new file mode 100644
index 0000000..3d25943
--- /dev/null
+++ b/enclave/env/env.cc
@@ -0,0 +1,80 @@
+// Copyright 2023 Signal Messenger, LLC
+// SPDX-License-Identifier: AGPL-3.0-only
+
+#include "env/env.h"
+#include "util/macros.h"
+#include
+#include
+
+namespace svr2::env {
+
+namespace {
+
+// Placeholder environment installed in the global `environment` pointer until
+// env::Init() replaces it.  Every operation CHECK-fails with a message naming
+// the problem: `nullptr == "<literal>"` is always false, and the literal is
+// there so the text shows up in the failure output.
+class UnsetEnvironment : public Environment {
+ public:
+  virtual ~UnsetEnvironment() {}
+  virtual std::pair Evidence(const PublicKey& key, const enclaveconfig::RaftGroupConfig& config) const {
+    CHECK(nullptr == "env::Init not called, environment not initiated");
+
+    return std::make_pair(e2e::Attestation(), error::General_Unimplemented);
+  }
+  // Given evidence and endorsements, extract the key.
+  virtual std::pair Attest(
+      util::UnixSecs now,
+      const std::string& evidence,
+      const std::string& endorsements) const {
+    CHECK(nullptr == "env::Init not called, environment not initiated");
+    std::array out = {0};
+    return std::make_pair(out, error::General_Unimplemented);
+  }
+  // Fill the [size]-byte buffer at [bytes] with random bytes.
+  virtual error::Error RandomBytes(void* bytes, size_t size) const {
+    CHECK(nullptr == "env::Init not called, environment not initiated");
+    return error::General_Unimplemented;
+  }
+
+  virtual error::Error SendMessage(const std::string& msg) const {
+    CHECK(nullptr == "env::Init not called, environment not initiated");
+    return error::General_Unimplemented;
+  }
+
+  virtual void Log(int level, const std::string& msg) const {
+    CHECK(nullptr == "env::Init not called, environment not initiated");
+  }
+
+  virtual error::Error UpdateEnvStats() const {
+    CHECK(nullptr == "env::Init not called, environment not initiated");
+    return error::General_Unimplemented;
+  }
+};
+
+// libsodium `randombytes_implementation` callbacks.  Each forwards to the
+// current global `environment` and CHECK-fails if it reports an error.
+const char* env_randombytes_name() { return "env"; }
+uint32_t env_randombytes_uint32() {
+  uint32_t out;
+  CHECK(error::OK == environment->RandomBytes(&out, sizeof(out)));
+  return out;
+}
+void env_randombytes_bytes(void* const buf, const size_t size) {
+  CHECK(error::OK == environment->RandomBytes(buf, size));
+}
+// Only the name, `random` and `buf` hooks are provided; the remaining members
+// of the struct are left zero-initialized.
+randombytes_implementation sodium_randombytes_impl = {
+  .implementation_name = env_randombytes_name,
+  .random = env_randombytes_uint32,
+  .buf = env_randombytes_bytes,
+};
+
+}  // namespace
+
+// Starts out as UnsetEnvironment; the per-platform env::Init() swaps in the
+// real implementation.
+std::unique_ptr environment(new UnsetEnvironment());
+
+Environment::Environment() {
+}
+
+// Routes libsodium's random number generation through this environment, then
+// initializes libsodium.  CHECK-fails if either step fails.
+void Environment::Init() {
+  // The custom randombytes implementation is installed before sodium_init().
+  CHECK(0 == randombytes_set_implementation(&sodium_randombytes_impl));
+  // sodium_init returns 0 or 1 on success, -1 on failure.
+  CHECK(sodium_init() >= 0);
+}
+
+}  // namespace svr2::env
diff --git a/enclave/env/env.h b/enclave/env/env.h
new file mode 100644
index 0000000..7d9344b
--- /dev/null
+++ b/enclave/env/env.h
@@ -0,0 +1,50 @@
+// Copyright 2023 Signal Messenger, LLC
+// SPDX-License-Identifier: AGPL-3.0-only
+
+#ifndef __SVR2_ENV_ENV_H__
+#define __SVR2_ENV_ENV_H__
+
+#include
+#include
+#include "proto/error.pb.h"
+#include "proto/e2e.pb.h"
+#include "proto/msgs.pb.h"
+#include "util/macros.h"
+#include "util/ticks.h"
+
+namespace svr2::env {
+
+typedef std::array PublicKey;
+
+// Abstract interface to the platform the enclave code runs on (attestation,
+// randomness, host messaging, logging, stats).  One concrete subclass is
+// installed in the global `environment` by env::Init().
+class Environment {
+ public:
+  DELETE_COPY_AND_ASSIGN(Environment);
+  Environment();
+  virtual ~Environment() {}
+  // Shared initialization.  The base implementation sets up libsodium.
+  virtual void Init();
+  // Given a 32-byte key, return evidence of that key (an OpenEnclave report).
+  // NOTE(review): the evidence format depends on the backend; the NSM backend
+  // returns an NSM attestation document instead.
+  virtual std::pair Evidence(const PublicKey& key, const enclaveconfig::RaftGroupConfig& config) const = 0;
+  // Given evidence and endorsements, extract the key.
+  virtual std::pair Attest(
+      util::UnixSecs now,
+      const std::string& evidence,
+      const std::string& endorsements) const = 0;
+  // Fill the [size]-byte buffer at [bytes] with random bytes.
+  virtual error::Error RandomBytes(void* bytes, size_t size) const = 0;
+  // Send a message from enclave to host.  [msg] should be a serialized
+  // EnclaveMessage.
+  virtual error::Error SendMessage(const std::string& msg) const = 0;
+  // Log a message to a logging framework.
+  virtual void Log(int level, const std::string& msg) const = 0;
+  // Update env-specific statistics.
+  virtual error::Error UpdateEnvStats() const = 0;
+};
+
+// The process-wide environment.  Holds a placeholder that CHECK-fails on use
+// until Init() below has been called.
+extern std::unique_ptr environment;
+
+// Installs the platform-specific Environment into `environment`.  Each backend
+// (sgx, nsm, test) supplies its own definition of this function.
+void Init(bool is_simulated = true);
+
+}  // namespace svr2::env
+
+#endif  // __SVR2_ENV_ENV_H__
diff --git a/enclave/env/nsm/nsm.cc b/enclave/env/nsm/nsm.cc
new file mode 100644
index 0000000..a314713
--- /dev/null
+++ b/enclave/env/nsm/nsm.cc
@@ -0,0 +1,114 @@
+// Copyright 2023 Signal Messenger, LLC
+// SPDX-License-Identifier: AGPL-3.0-only
+
+#include
+#include
+#include
+
+#include "env/env.h"
+#include "util/macros.h"
+#include "context/context.h"
+#include "socketwrap/socket.h"
+#include "proto/nitro.pb.h"
+#include "queue/queue.h"
+
+namespace svr2::env {
+namespace nsm {
+namespace {
+
+// Messages queued by Environment::SendMessage and drained by SendNsmMessages.
+// NOTE(review): 100 is presumably the queue capacity -- confirm in queue.h.
+static queue::Queue output_messages(100);
+
+// Environment backed by the Nitro Secure Module (NSM) library.
+class Environment : public ::svr2::env::Environment {
+ public:
+  DELETE_COPY_AND_ASSIGN(Environment);
+  // Opens the NSM device.
+  // NOTE(review): the descriptor returned by nsm_lib_init() is not checked for
+  // failure before it is used by the methods below.
+  Environment() {
+    nsm_fd_ = nsm_lib_init();
+  }
+  virtual ~Environment() {
+    nsm_lib_exit(nsm_fd_);
+  }
+  // Requests an attestation document from the NSM for [key] and the serialized
+  // [config], and returns it in Attestation.evidence.
+  // Returns Env_SerializeCustomClaims if [config] cannot be serialized and
+  // Env_AttestationFailure if the NSM call fails.
+  virtual std::pair Evidence(const PublicKey& key, const enclaveconfig::RaftGroupConfig& config) const {
+    e2e::Attestation out;
+    // Reserve a fixed 4096-byte output buffer; evidence_len carries that
+    // capacity into the call.
+    out.mutable_evidence()->resize(4096);
+    uint32_t evidence_len = out.evidence().size();
+    std::string config_serialized;
+    if (!config.SerializeToString(&config_serialized)) {
+      return std::make_pair(out, error::Env_SerializeCustomClaims);
+    }
+    // Argument pairs are (pointer, length).
+    // NOTE(review): they presumably map to user data, nonce (none supplied)
+    // and public key in the NSM API -- confirm against the library header.
+    if (ERROR_CODE_SUCCESS != nsm_get_attestation_doc(
+        nsm_fd_,
+        reinterpret_cast(config_serialized.data()),
+        config_serialized.size(),
+        nullptr,
+        0,
+        key.data(),
+        key.size(),
+        reinterpret_cast(out.mutable_evidence()->data()),
+        &evidence_len)) {
+      return std::make_pair(out, error::Env_AttestationFailure);
+    }
+    // Shrink the buffer to the length written back into evidence_len.
+    out.mutable_evidence()->resize(evidence_len);
+    return std::make_pair(out, error::OK);
+  }
+
+  // Given evidence and endorsements, extract the key.
+ virtual std::pair Attest( + util::UnixSecs now, + const std::string& evidence, + const std::string& endorsements) const { + std::array out = {0}; + return std::make_pair(out, error::General_Unimplemented); + } + + // Given a string of size N, rewrite all bytes in that string with + // random bytes. + virtual error::Error RandomBytes(void* bytes, size_t size) const { + uintptr_t received; + uint8_t* u8ptr = reinterpret_cast(bytes); + while (size) { + received = size; + if (ERROR_CODE_SUCCESS != nsm_get_random(nsm_fd_, u8ptr, &received)) { + return error::Env_RandomBytes; + } + size -= received; + u8ptr += received; + } + return error::OK; + } + + virtual error::Error SendMessage(const std::string& msg) const { + output_messages.Push(msg); + return error::OK; + } + + virtual void Log(int level, const std::string& msg) const { + } + + virtual error::Error UpdateEnvStats() const { + return error::General_Unimplemented; + } + + private: + int32_t nsm_fd_; +}; + +} // namespace + +error::Error SendNsmMessages(socketwrap::Socket* sock) { + while (true) { + context::Context ctx; + for (int i = 0; i < 100; i++) { + auto out = ctx.Protobuf(); + *out->mutable_out() = output_messages.Pop(); + RETURN_IF_ERROR(sock->WritePB(&ctx, *out)); + } + } +} + +} // namespace nsm + +void Init(bool is_simulated) { + environment = std::make_unique<::svr2::env::nsm::Environment>(); +} + +} // namespace svr2::env diff --git a/enclave/env/nsm/nsm.h b/enclave/env/nsm/nsm.h new file mode 100644 index 0000000..cc35e3d --- /dev/null +++ b/enclave/env/nsm/nsm.h @@ -0,0 +1,17 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_ENV_NSM_NSM_H__ +#define __SVR2_ENV_NSM_NSM_H__ + +#include "socketwrap/socket.h" +#include "proto/error.pb.h" + +namespace svr2::env::nsm { + +// Send all outstanding messages, in order, up to the host. 
+error::Error SendNsmMessages(socketwrap::Socket* sock); + +} // namespace svr2::env::nsm + +#endif // __SVR2_ENV_NSM_NSM_H__ diff --git a/enclave/env/sgx/sgx.cc b/enclave/env/sgx/sgx.cc new file mode 100644 index 0000000..4384830 --- /dev/null +++ b/enclave/env/sgx/sgx.cc @@ -0,0 +1,272 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "attestation/attestation.h" +#include "env/env.h" +#include "metrics/metrics.h" +#include "svr2/svr2_t.h" +#include "util/constant.h" +#include "util/log.h" + +namespace svr2::env { +namespace sgx { + +static const char* unattested_evidence_prefix = "UNATTESTED EVIDENCE:"; +static const char* custom_claim_pk = "pk"; +static const char* custom_claim_config = "config"; +class Environment : public ::svr2::env::Environment { + public: + DELETE_COPY_AND_ASSIGN(Environment); + Environment(bool simulated) : ::svr2::env::Environment(), simulated_(simulated) { + if (!simulated_) { + CHECK(OE_OK == oe_attester_initialize()); + CHECK(OE_OK == oe_verifier_initialize()); + CHECK(error::OK == GetMRENCLAVE()); + } + } + + virtual ~Environment() { + if (!simulated_) { + oe_attester_shutdown(); + oe_verifier_shutdown(); + } + } + + virtual std::pair Evidence( + const PublicKey& key, const enclaveconfig::RaftGroupConfig& config) const { + e2e::Attestation attestation; + if (simulated_) { + attestation.set_evidence( + unattested_evidence_prefix + + std::string(reinterpret_cast(key.data()), key.size())); + return std::make_pair(attestation, error::OK); + } + std::string serialized_config; + if (!config.SerializeToString(&serialized_config)) { + return std::make_pair(e2e::Attestation(), COUNTED_ERROR(Env_SerializeConfigForEvidence)); + } + + uint8_t* custom_claims_buffer = NULL; + size_t custom_claims_buffer_size = 0; + oe_claim_t custom_claims[] = { + { + .name = const_cast(custom_claim_pk), + .value = 
const_cast(key.data()), + .value_size = key.size(), + }, + { + .name = const_cast(custom_claim_config), + .value = reinterpret_cast(serialized_config.data()), + .value_size = serialized_config.size(), + }, + }; + if (OE_OK != oe_serialize_custom_claims(custom_claims, sizeof(custom_claims) / sizeof(custom_claims[0]), + &custom_claims_buffer, + &custom_claims_buffer_size)) { + return std::make_pair(e2e::Attestation(), + COUNTED_ERROR(Env_SerializeCustomClaims)); + } + std::unique_ptr free_cc( + custom_claims_buffer, oe_free_serialized_custom_claims); + + uint8_t* evidence_buffer = NULL; + size_t evidence_buffer_size = 0; + uint8_t* endorsements_buffer = NULL; + size_t endorsements_buffer_size = 0; + if (OE_OK != oe_get_evidence(&attestation::sgx_remote_uuid, 0, custom_claims_buffer, + custom_claims_buffer_size, NULL, 0, + &evidence_buffer, &evidence_buffer_size, + &endorsements_buffer, + &endorsements_buffer_size)) { + return std::make_pair(e2e::Attestation(), error::Env_GetEvidence); + } + + std::unique_ptr free_evidence( + evidence_buffer, oe_free_evidence); + std::unique_ptr free_endorsements( + endorsements_buffer, oe_free_endorsements); + + std::string evidence((char*)evidence_buffer, evidence_buffer_size); + std::string endorsements((char*)endorsements_buffer, + endorsements_buffer_size); + + attestation.set_evidence(evidence); + attestation.set_endorsements(endorsements); + return std::make_pair(attestation, error::OK); + } + + virtual error::Error RandomBytes(void* bytes, size_t size) const { + CHECK(size > 0); + if (OE_OK != oe_random(bytes, size)) { + return COUNTED_ERROR(Env_RandomBytes); + } + return error::OK; + } + + virtual std::pair Attest( + util::UnixSecs now, + const std::string& evidence, + const std::string& endorsements) const { + PublicKey out = {0}; + + if (simulated_) { + if (evidence.size() != strlen(unattested_evidence_prefix) + out.size() || + evidence.substr(0, strlen(unattested_evidence_prefix)) != + unattested_evidence_prefix) { + 
return std::make_pair(out, error::Env_AttestationFailure); + } + memcpy(out.data(), evidence.data() + strlen(unattested_evidence_prefix), + out.size()); + return std::make_pair(out, error::OK); + } + const uint8_t* evidence_data = + reinterpret_cast(evidence.data()); + const uint8_t* endorsements_data = + reinterpret_cast(endorsements.data()); + + oe_claim_t* claims = nullptr; + size_t claims_length = 0; + + oe_datetime_t now_datetime; + SecsToOEDatetime(now, &now_datetime); + oe_policy_t policy = { + .type = OE_POLICY_ENDORSEMENTS_TIME, + .policy = &now_datetime, + .policy_size = sizeof(now_datetime), + }; + auto verify_err = oe_verify_evidence( + &attestation::sgx_remote_uuid, evidence_data, evidence.size(), endorsements_data, + endorsements.size(), &policy, 1, &claims, &claims_length); + if (OE_OK != verify_err) { + LOG(ERROR) << "oe_verify_evidence failed with code " << verify_err; + return std::make_pair(out, error::Env_AttestationFailure); + } + + auto free_claims_known_size = [claims_length](oe_claim_t* ptr) { + return oe_free_claims(ptr, claims_length); + }; + std::unique_ptr free_claims( + claims, free_claims_known_size); + + // evidence is verified, now check individual fields + error::Error err = ValidateStandardClaims(claims, claims_length); + if (error::OK != err) { + return std::make_pair(out, err); + } + + err = attestation::ReadKeyFromVerifiedClaims(claims, claims_length, out); + + return std::make_pair(out, err); + } + + virtual error::Error SendMessage(const std::string& msg) const { + if (OE_OK != + svr2_output_message( + msg.size(), const_cast( + reinterpret_cast(msg.data())))) { + return COUNTED_ERROR(Env_SendMessage); + } + return error::OK; + } + + virtual void Log(int level, const std::string& msg) const { + oe_log_ocall(level, msg.c_str()); + } + + virtual error::Error UpdateEnvStats() const { + oe_mallinfo_t info; + if (OE_OK != oe_allocator_mallinfo(&info)) { + return COUNTED_ERROR(Env_MallinfoFailure); + } + GAUGE(env, 
total_heap_size)->Set(info.max_total_heap_size); + GAUGE(env, allocated_heap_size)->Set(info.current_allocated_heap_size); + GAUGE(env, peak_heap_size)->Set(info.peak_allocated_heap_size); + return error::OK; + } + + private: + bool simulated_; + std::string expected_mrenclave_; + error::Error GetMRENCLAVE() { + auto [attestation, err] = Evidence(PublicKey{0}, enclaveconfig::RaftGroupConfig()); + if (err != error::OK) { + return err; + } + + auto [claims, claims_length] = attestation::VerifyAndReadClaims( + attestation.evidence(), attestation.endorsements()); + + auto free_claims_known_size = [claims_length=claims_length](oe_claim_t* ptr) { + return oe_free_claims(ptr, claims_length); + }; + std::unique_ptr free_claims( + claims, free_claims_known_size); + + // read the MRENCLAVE - this is our MRENCLAVE and we expect all peers to + // have the same value OE_CLAIM_UNIQUE_ID retrieves MRENCLAVE on SGX + const oe_claim_t* claim; + if ((claim = attestation::FindClaim(claims, claims_length, + OE_CLAIM_UNIQUE_ID)) == nullptr) { + return COUNTED_ERROR(Env_AttestationFailure); + } + expected_mrenclave_ = std::string( + reinterpret_cast(claim->value), claim->value_size); + return error::OK; + } + + error::Error ValidateStandardClaims(oe_claim_t* claims, + size_t claims_length) const { + const oe_claim_t* claim; + + // OE_CLAIM_UNIQUE_ID is MRENCLAVE for SGX + if ((claim = attestation::FindClaim(claims, claims_length, + OE_CLAIM_UNIQUE_ID)) == nullptr) { + return COUNTED_ERROR(Env_MissingMRENCLAVE); + } + auto actual_mrenclave = std::string( + reinterpret_cast(claim->value), claim->value_size); + + // Don't need constant time, but we have it so we use it. + if (!util::ConstantTimeEquals(actual_mrenclave, expected_mrenclave_)) { + return COUNTED_ERROR(Env_WrongMRENCLAVE); + } + + return error::OK; + } + + static void SecsToOEDatetime(util::UnixSecs secs, oe_datetime_t* dt) { + // Mostly copied from oe_datetime_now in OpenEnclave's common/datetime.c. 
+ // Unfortunately, they expose the ability to get from "now", but not + // from an arbitrary timestamp. + CHECK(dt != nullptr); + struct tm timeinfo; + + gmtime_r(&secs, &timeinfo); + + dt->year = (uint32_t)timeinfo.tm_year + 1900; + dt->month = (uint32_t)timeinfo.tm_mon + 1; + dt->day = (uint32_t)timeinfo.tm_mday; + dt->hours = (uint32_t)timeinfo.tm_hour; + dt->minutes = (uint32_t)timeinfo.tm_min; + dt->seconds = (uint32_t)timeinfo.tm_sec; + } +}; + +} // namespace sgx + +void Init(bool is_simulated) { + environment = std::make_unique<::svr2::env::sgx::Environment>(is_simulated); + environment->Init(); +} + +} // namespace svr2::env diff --git a/enclave/env/test/test.cc b/enclave/env/test/test.cc new file mode 100644 index 0000000..8ef5a15 --- /dev/null +++ b/enclave/env/test/test.cc @@ -0,0 +1,94 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "env/test/test.h" +#include "env/env.h" +#include "util/mutex.h" +#include +#include +#include +#include + +namespace svr2::env { +namespace test { + +static const char* evidence_prefix = "EVIDENCE:"; +static volatile std::atomic random_gen; + +class Environment : public ::svr2::env::Environment { + public: + DELETE_COPY_AND_ASSIGN(Environment); + Environment() : ::svr2::env::Environment() {} + virtual ~Environment() {} + virtual std::pair Evidence(const PublicKey& key, const enclaveconfig::RaftGroupConfig& config) const { + e2e::Attestation attestation; + attestation.set_evidence(evidence_prefix + std::string(reinterpret_cast(key.data()), key.size())); + return std::make_pair(attestation, error::OK); + } + + virtual error::Error RandomBytes(void* bytes, size_t size) const { + // We could do this reading in a while loop, but we expect it should be fine. + // Rewrite this if tests fail because of it. 
+ CHECK(size > 0); + uint8_t* ptr = reinterpret_cast(bytes); + for (size_t i = 0; i < size; i++) { + uint32_t next = std::atomic_fetch_add(&random_gen, 1U); + // This keeps the sequence of bytes relatively non-repeating for the first 4GB. + *ptr++ = (uint8_t)(next ^ (next >> 8) ^ (next >> 16) ^ (next >> 24)); + } + return error::OK; + } + + virtual std::pair Attest( + util::UnixSecs now, + const std::string& evidence, + const std::string& endorsements) const { + PublicKey out = {0}; + if (evidence.size() != strlen(evidence_prefix) + out.size() + || evidence.substr(0, strlen(evidence_prefix)) != evidence_prefix) { + return std::make_pair(out, error::Env_AttestationFailure); + } + memcpy(out.data(), evidence.data() + strlen(evidence_prefix), out.size()); + return std::make_pair(out, error::OK); + } + + virtual error::Error SendMessage(const std::string& msg) const { + util::unique_lock ul(mu_); + EnclaveMessage m; + CHECK(m.ParseFromString(msg)); + sent_messages_.push_back(std::move(m)); + return error::OK; + } + + virtual void Log(int level, const std::string& msg) const { + fprintf(stderr, "%s\n", msg.c_str()); + } + + std::vector SentMessages() { + util::unique_lock ul(mu_); + return std::move(sent_messages_); + } + + virtual error::Error UpdateEnvStats() const { + return error::OK; + } + + private: + mutable util::mutex mu_; + mutable std::vector sent_messages_ GUARDED_BY(mu_); +}; + +std::vector SentMessages() { + Environment* e = dynamic_cast(::svr2::env::environment.get()); + CHECK(e != nullptr); + return e->SentMessages(); +} + +} // namespace test + +void Init(bool is_simulated) { + environment = std::make_unique<::svr2::env::test::Environment>(); + environment->Init(); +} + +} // namespace svr2::env diff --git a/enclave/env/test/test.h b/enclave/env/test/test.h new file mode 100644 index 0000000..e28844f --- /dev/null +++ b/enclave/env/test/test.h @@ -0,0 +1,17 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef 
__SVR2_ENV_TEST_TEST_H__ +#define __SVR2_ENV_TEST_TEST_H__ + +#include +#include +#include "proto/msgs.pb.h" + +namespace svr2::env::test { + +std::vector SentMessages(); + +} // namespace svr2::env::test + +#endif // __SVR2_ENV_TEST_TEST_H__ diff --git a/enclave/env/test/tests/testrand.cc b/enclave/env/test/tests/testrand.cc new file mode 100644 index 0000000..ce4fb6e --- /dev/null +++ b/enclave/env/test/tests/testrand.cc @@ -0,0 +1,28 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP util +//TESTDEP env +//TESTDEP env/test +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include "env/env.h" +#include "util/log.h" +#include "util/hex.h" + +namespace svr2::env { + +TEST(EnvTest, Random) { + Init(); + uint8_t got[260]; + ASSERT_EQ(error::OK, environment->RandomBytes(got, sizeof(got))); + LOG(INFO) << "Bytes: " << util::BytesToHex(got, 8); + uint8_t expect_first[] = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17}; + ASSERT_EQ(0, memcmp(got, expect_first, sizeof(expect_first))); +} + +} // namespace svr2::env diff --git a/enclave/find_header.sh b/enclave/find_header.sh new file mode 100755 index 0000000..23e206c --- /dev/null +++ b/enclave/find_header.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# +# Given a compiler and a header file, return what directory that header is located in. +# +set -e +if [[ $# != 2 ]]; then + echo 1>&2 "Usage: $0
" + exit 1 +fi +COMPILER=$1 +HEADER=$2 +LISTING="" +"$COMPILER" -E -x c++ - -v &1 | while read line +do + if [[ $line == "#include <...> search starts here:" ]]; then + LISTING=1 + elif [[ $line == "End of search list." ]]; then + exit 1 + elif [[ $LISTING != "" ]]; then + if ls "$line/$HEADER" >/dev/null 2>/dev/null; then + echo "$line" + exit 0 + fi + fi +done diff --git a/enclave/googletest b/enclave/googletest new file mode 160000 index 0000000..3026483 --- /dev/null +++ b/enclave/googletest @@ -0,0 +1 @@ +Subproject commit 3026483ae575e2de942db5e760cf95e973308dd5 diff --git a/enclave/groupclock/groupclock.cc b/enclave/groupclock/groupclock.cc new file mode 100644 index 0000000..863d359 --- /dev/null +++ b/enclave/groupclock/groupclock.cc @@ -0,0 +1,52 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "groupclock/groupclock.h" + +#include +#include +#include "util/log.h" + +namespace svr2::groupclock { + +void Clock::SetLocalTime(util::UnixSecs secs) { + local_.store(secs); +} + +void Clock::SetRemoteTime(context::Context* ctx, const peerid::PeerID& peer, util::UnixSecs secs) { + ACQUIRE_LOCK(mu_, ctx, lock_groupclock); + remotes_[peer] = secs; +} + +util::UnixSecs Clock::GetTime(context::Context* ctx, const std::set& remotes) const { + std::vector secs(1 /* local_ */ + remotes.size()); + ACQUIRE_LOCK(mu_, ctx, lock_groupclock); + auto set_iter = remotes.begin(); + auto map_iter = remotes_.begin(); + secs[0] = local_.load(); + size_t secs_size = 1; + while (set_iter != remotes.end() && map_iter != remotes_.end()) { + const peerid::PeerID& set_peer = *set_iter; + const peerid::PeerID& map_peer = map_iter->first; + if (set_peer < map_peer) { + ++set_iter; + } else if (map_peer < set_peer) { + ++map_iter; + } else { + secs[secs_size++] = map_iter->second; + ++set_iter; + ++map_iter; + } + } + secs.resize(secs_size); + // `secs` now contains a list of my timestamp and the timestamps of all + // peers in `remotes` 
that we've received a timestamp from. Get the median. + std::sort(secs.begin(), secs.end()); + return secs[secs.size()/2]; +} + +util::UnixSecs Clock::GetLocalTime() const { + return local_.load(); +} + +} // namespace svr2::groupclock diff --git a/enclave/groupclock/groupclock.h b/enclave/groupclock/groupclock.h new file mode 100644 index 0000000..edb635c --- /dev/null +++ b/enclave/groupclock/groupclock.h @@ -0,0 +1,37 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_GROUPCLOCK_GROUPCLOCK_H__ +#define __SVR2_GROUPCLOCK_GROUPCLOCK_H__ + +#include +#include +#include +#include "util/macros.h" +#include "util/mutex.h" +#include "util/ticks.h" +#include "peerid/peerid.h" +#include "context/context.h" + +namespace svr2::groupclock { + +// Clock that returns time based on times reported from a group of +// peers. The reported time will be the median of all reported times. +class Clock { + public: + DELETE_COPY_AND_ASSIGN(Clock); + Clock() : local_(0) {}; + void SetLocalTime(util::UnixSecs secs); + void SetRemoteTime(context::Context* ctx, const peerid::PeerID& peer, util::UnixSecs secs) EXCLUDES(mu_); + util::UnixSecs GetTime(context::Context* ctx, const std::set& remotes) const EXCLUDES(mu_); + util::UnixSecs GetLocalTime() const; + + private: + mutable util::mutex mu_; + std::atomic local_; + std::map remotes_ GUARDED_BY(mu_); +}; + +} // namespace svr2::groupclock + +#endif diff --git a/enclave/groupclock/tests/groupclock.cc b/enclave/groupclock/tests/groupclock.cc new file mode 100644 index 0000000..6b8246c --- /dev/null +++ b/enclave/groupclock/tests/groupclock.cc @@ -0,0 +1,55 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP peerid +//TESTDEP sip +//TESTDEP sender +//TESTDEP context +//TESTDEP env +//TESTDEP env/test +//TESTDEP util +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include 
"groupclock/groupclock.h" +#include "env/env.h" +#include "context/context.h" + +namespace svr2::groupclock { + +class ClockTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } + context::Context ctx; +}; + +TEST_F(ClockTest, BasicUsage) { + Clock c; + EXPECT_EQ(0, c.GetTime(&ctx, std::set{})); + c.SetLocalTime(1000); + EXPECT_EQ(1000, c.GetTime(&ctx, std::set{})); + peerid::PeerID p1((uint8_t[32]){1}); + peerid::PeerID p2((uint8_t[32]){2}); + peerid::PeerID p3((uint8_t[32]){3}); + peerid::PeerID p4((uint8_t[32]){4}); + c.SetRemoteTime(&ctx, p1, 1001); + c.SetRemoteTime(&ctx, p2, 1002); + c.SetRemoteTime(&ctx, p3, 1003); + c.SetRemoteTime(&ctx, p4, 1004); + EXPECT_EQ(1001, c.GetTime(&ctx, std::set{p1})); + EXPECT_EQ(1001, c.GetTime(&ctx, std::set{p1, p2})); + EXPECT_EQ(1002, c.GetTime(&ctx, std::set{p1, p2, p3})); + EXPECT_EQ(1002, c.GetTime(&ctx, std::set{p1, p2, p3, p4})); + c.SetLocalTime(1005); + EXPECT_EQ(1003, c.GetTime(&ctx, std::set{p1, p2, p3, p4})); + c.SetRemoteTime(&ctx, p1, 1004); + EXPECT_EQ(1004, c.GetTime(&ctx, std::set{p1, p2, p3, p4})); +} + +} // namespace svr2::groupclock diff --git a/enclave/gtest/gtest-all.cc b/enclave/gtest/gtest-all.cc new file mode 120000 index 0000000..1d935ef --- /dev/null +++ b/enclave/gtest/gtest-all.cc @@ -0,0 +1 @@ +../googletest/googletest/src/gtest-all.cc \ No newline at end of file diff --git a/enclave/gtest/gtest_main.cc b/enclave/gtest/gtest_main.cc new file mode 120000 index 0000000..6ed8d70 --- /dev/null +++ b/enclave/gtest/gtest_main.cc @@ -0,0 +1 @@ +../googletest/googletest/src/gtest_main.cc \ No newline at end of file diff --git a/enclave/hmac/hmac.cc b/enclave/hmac/hmac.cc new file mode 100644 index 0000000..c89c519 --- /dev/null +++ b/enclave/hmac/hmac.cc @@ -0,0 +1,25 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "hmac/hmac.h" +#include +#include + +namespace svr2::hmac { + +std::array Sha256(const 
std::string& input) { + crypto_hash_sha256_state sha; + crypto_hash_sha256_init(&sha); + crypto_hash_sha256_update(&sha, reinterpret_cast(input.data()), input.size()); + std::array out; + crypto_hash_sha256_final(&sha, out.data()); + return out; +} + +std::array HmacSha256(const std::array& key, const std::string& input) { + std::array out; + crypto_auth_hmacsha256(out.data(), reinterpret_cast(input.data()), input.size(), key.data()); + return out; +} + +} // namespace svr2::hmac diff --git a/enclave/hmac/hmac.h b/enclave/hmac/hmac.h new file mode 100644 index 0000000..1b5c67e --- /dev/null +++ b/enclave/hmac/hmac.h @@ -0,0 +1,17 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_HMAC_HMAC_H__ +#define __SVR2_HMAC_HMAC_H__ + +#include +#include + +namespace svr2::hmac { + +std::array Sha256(const std::string& input); +std::array HmacSha256(const std::array& key, const std::string& input); + +} // namespace svr2::hmac + +#endif // __SVR2_HMAC_HMAC_H__ diff --git a/enclave/hmac/tests/hmac.cc b/enclave/hmac/tests/hmac.cc new file mode 100644 index 0000000..04561f8 --- /dev/null +++ b/enclave/hmac/tests/hmac.cc @@ -0,0 +1,44 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP hmac +//TESTDEP noise-c +//TESTDEP libsodium + +#include +#include + +#include + +#include "hmac/hmac.h" + +namespace svr2::hmac { + +class HmacTest : public ::testing::Test { +}; + +TEST_F(HmacTest, BasicUsage) { + std::array key = { + '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', + '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', + '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', + '1', '2'}; + std::array out = HmacSha256(key, "abc"); + + // Python3: + // >>> import base64 + // >>> import hmac + // >>> import hashlib + // >>> base64.b16encode(hmac.digest(b'12345678901234567890123456789012', b'abc', hashlib.sha256)) + // 
b'26B7F4C64769835D3F654DC635D5362988C270883270E1EFD65372B5F3100BAF' + std::array expected = { + 0x26, 0xB7, 0xF4, 0xC6, 0x47, 0x69, 0x83, 0x5D, + 0x3F, 0x65, 0x4D, 0xC6, 0x35, 0xD5, 0x36, 0x29, + 0x88, 0xC2, 0x70, 0x88, 0x32, 0x70, 0xE1, 0xEF, + 0xD6, 0x53, 0x72, 0xB5, 0xF3, 0x10, 0x0B, 0xAF, + }; + EXPECT_EQ(out, expected); +} + +} // namespace svr2::hmac diff --git a/enclave/libsodium b/enclave/libsodium new file mode 160000 index 0000000..fd5cbe9 --- /dev/null +++ b/enclave/libsodium @@ -0,0 +1 @@ +Subproject commit fd5cbe9e696c1b886e45f3111dd099d51b12de6e diff --git a/enclave/metrics/counters.h b/enclave/metrics/counters.h new file mode 100644 index 0000000..a7f6c6f --- /dev/null +++ b/enclave/metrics/counters.h @@ -0,0 +1,114 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// This file contains all counter metrics used within SVR2. +// +// They're created with the macro CREATE_COUNTER, which takes arguments: +// * ns - namespace of the counter (generally, module name) +// * varname - name of the variable used to reference this counter, must be +// unique within the namespace (ns) +// * name - name of the exported variable (actually, "ns.name") +// * tags - set of tags associated with this variable, either empty `({})`, or +// an initializer list `({{"foo", "bar"}, {"baz", "blah"}})` for tags +// foo=bar, baz=blah. Must be wrapped in parens. +// +// Once these counters are created here, they're used with the incantation: +// COUNTER(ns, varname)->CounterFunction(); +// IE: +// COUNTER(sender, enclave_messages_sent)->IncrementBy(3); +// +// All counters created here will be exported to the host, even if they are +// zero. This differs from error counts, which are exported only if non-zero. 
+ +CREATE_COUNTER(ecalls, host_messages_received, host_messages_received, ({})) +CREATE_COUNTER(ecalls, host_bytes_received, host_bytes_received, ({})) +CREATE_COUNTER(ecalls, init_calls, init_calls, ({})) + +CREATE_COUNTER(sender, enclave_messages_sent, enclave_messages_sent, ({})) +CREATE_COUNTER(sender, enclave_bytes_sent, enclave_bytes_sent, ({})) + +CREATE_COUNTER(core, host_requests_received, msgs_received, ({{"type", "host_request"}})) +CREATE_COUNTER(core, peer_msgs_received, msgs_received, ({{"type", "peer_message"}})) +CREATE_COUNTER(core, timer_ticks_received, msgs_received, ({{"type", "timer_tick"}})) +CREATE_COUNTER(core, invalid_msgs_received, msgs_received, ({{"type", "invalid"}})) +CREATE_COUNTER(core, new_client_success, new_clients, ({{"outcome", "success"}})) +CREATE_COUNTER(core, new_client_failure, new_clients, ({{"outcome", "failure"}})) +CREATE_COUNTER(core, log_transactions_success, log_transactions, ({{"outcome", "success"}})) +CREATE_COUNTER(core, log_transactions_cancelled, log_transactions, ({{"outcome", "cancelled"}})) +CREATE_COUNTER(core, host_delete_success, host_delete, ({{"outcome", "success"}})) +CREATE_COUNTER(core, host_delete_failure, host_delete, ({{"outcome", "failure"}})) +CREATE_COUNTER(core, client_transaction_success, client_transaction, ({{"outcome", "success"}})) +CREATE_COUNTER(core, client_transaction_cancelled, client_transaction, ({{"outcome", "cancelled"}})) +CREATE_COUNTER(core, client_transaction_error, client_transaction, ({{"outcome", "error"}})) +CREATE_COUNTER(core, client_transaction_invalid, client_transaction, ({{"outcome", "invalid"}})) +CREATE_COUNTER(core, client_transaction_dne, client_transaction, ({{"outcome", "dne"}})) +CREATE_COUNTER(core, client_transaction_encrypterr, client_transaction, ({{"outcome", "encrypterr"}})) +CREATE_COUNTER(core, raft_log_applied, raft_log_applied, ({})) + +CREATE_COUNTER(client, created, created, ({})) +CREATE_COUNTER(client, closed, closed, ({})) 
+CREATE_COUNTER(client, new_dh_state, new_dh_state, ({})) +CREATE_COUNTER(client, key_rotate_success, key_rotate, ({{"outcome", "success"}})) +CREATE_COUNTER(client, key_rotate_failure, key_rotate, ({{"outcome", "failure"}})) +CREATE_COUNTER(client, attestation_refresh_success, attestation_refresh, ({{"outcome", "success"}})) +CREATE_COUNTER(client, attestation_refresh_failure, attestation_refresh, ({{"outcome", "failure"}})) + +CREATE_COUNTER(peers, attestation_refresh_success, attestation_refresh, ({{"outcome", "success"}})) +CREATE_COUNTER(peers, attestation_refresh_failure, attestation_refresh, ({{"outcome", "failure"}})) + +CREATE_COUNTER(raft, logs_committed, logs_committed, ({})) +CREATE_COUNTER(raft, logs_promised, logs_promised, ({})) +CREATE_COUNTER(raft, vote_requests_received, msgs_received, ({{"type", "vote_request"}})) +CREATE_COUNTER(raft, vote_responses_received, msgs_received, ({{"type", "vote_response"}})) +CREATE_COUNTER(raft, append_requests_received, msgs_received, ({{"type", "append_request"}})) +CREATE_COUNTER(raft, append_responses_received, msgs_received, ({{"type", "append_response"}})) +CREATE_COUNTER(raft, timeout_nows_received, msgs_received, ({{"type", "timeout_now"}})) +CREATE_COUNTER(raft, invalid_requests_received, msgs_received, ({{"type", "invalid"}})) +CREATE_COUNTER(raft, term_updated, term_updated, ({})) +CREATE_COUNTER(raft, term_increments, term_increments, ({})) +CREATE_COUNTER(raft, logs_append_success, logs_appended, ({{"outcome", "success"}})) +CREATE_COUNTER(raft, logs_append_failure, logs_appended, ({{"outcome", "failure"}})) +CREATE_COUNTER(raft, election_timeouts, election_timeouts, ({})) + +CREATE_COUNTER(timeout, timeouts_created, timeouts_created, ({})) +CREATE_COUNTER(timeout, timeouts_run, timeouts_completed, ({{"outcome", "run"}})) +CREATE_COUNTER(timeout, timeouts_cancelled, timeouts_completed, ({{"outcome", "cancelled"}})) + +CREATE_COUNTER(context, cpu_uncategorized, cpu, ({{"in", "uncategorized"}, {"action", 
"uncategorized"}})) +CREATE_COUNTER(context, cpu_client_encrypt, cpu, ({{"in", "client"}, {"action", "encrypt"}})) +CREATE_COUNTER(context, cpu_client_decrypt, cpu, ({{"in", "client"}, {"action", "decrypt"}})) +CREATE_COUNTER(context, cpu_client_hs_start, cpu, ({{"in", "client"}, {"action", "hs_start"}})) +CREATE_COUNTER(context, cpu_client_hs_finish, cpu, ({{"in", "client"}, {"action", "hs_finish"}})) +CREATE_COUNTER(context, cpu_peer_encrypt, cpu, ({{"in", "peer"}, {"action", "encrypt"}})) +CREATE_COUNTER(context, cpu_peer_decrypt, cpu, ({{"in", "peer"}, {"action", "decrypt"}})) +CREATE_COUNTER(context, cpu_peer_connect, cpu, ({{"in", "peer"}, {"action", "connect"}})) +CREATE_COUNTER(context, cpu_peer_connect2, cpu, ({{"in", "peer"}, {"action", "connect2"}})) +CREATE_COUNTER(context, cpu_peer_accept, cpu, ({{"in", "peer"}, {"action", "accept"}})) +CREATE_COUNTER(context, cpu_db_client_request, cpu, ({{"in", "db"}, {"action", "client_request"}})) +CREATE_COUNTER(context, cpu_db_repl_send, cpu, ({{"in", "db"}, {"action", "repl_send"}})) +CREATE_COUNTER(context, cpu_db_repl_recv, cpu, ({{"in", "db"}, {"action", "repl_recv"}})) +CREATE_COUNTER(context, cpu_db_hash, cpu, ({{"in", "db"}, {"action", "hash"}})) +CREATE_COUNTER(context, cpu_core_client_msg, cpu, ({{"in", "core"}, {"action", "client_msg"}})) +CREATE_COUNTER(context, cpu_core_peer_msg, cpu, ({{"in", "core"}, {"action", "peer_msg"}})) +CREATE_COUNTER(context, cpu_core_host_msg, cpu, ({{"in", "core"}, {"action", "host_msg"}})) +CREATE_COUNTER(context, cpu_core_raft_msg, cpu, ({{"in", "core"}, {"action", "raft_msg"}})) +CREATE_COUNTER(context, cpu_core_e2e_txn_req, cpu, ({{"in", "core"}, {"action", "e2e_txn_req"}})) +CREATE_COUNTER(context, cpu_core_e2e_txn_resp, cpu, ({{"in", "core"}, {"action", "e2e_txn_resp"}})) +CREATE_COUNTER(context, cpu_core_repl_send, cpu, ({{"in", "core"}, {"action", "repl_send"}})) +CREATE_COUNTER(context, cpu_core_repl_recv, cpu, ({{"in", "core"}, {"action", "repl_recv"}})) 
+CREATE_COUNTER(context, cpu_core_committed_logs, cpu, ({{"in", "core"}, {"action", "committed_logs"}})) +CREATE_COUNTER(context, cpu_core_timer_tick, cpu, ({{"in", "core"}, {"action", "timer_tick"}})) +CREATE_COUNTER(context, cpu_test_database_entries, cpu, ({{"in", "core"}, {"action", "test_database_entries"}})) +CREATE_COUNTER(context, lock_core_raft, cpu, ({{"in", "core"}, {"action", "lock"}, {"lock", "core_raft"}})) +CREATE_COUNTER(context, lock_core_log_txns, cpu, ({{"in", "core"}, {"action", "lock"}, {"lock", "core_log_txns"}})) +CREATE_COUNTER(context, lock_core_e2e_txns, cpu, ({{"in", "core"}, {"action", "lock"}, {"lock", "core_e2e_txns"}})) +CREATE_COUNTER(context, lock_core_config, cpu, ({{"in", "core"}, {"action", "lock"}, {"lock", "core_config"}})) +CREATE_COUNTER(context, lock_groupclock, cpu, ({{"in", "groupclock"}, {"action", "lock"}, {"lock", "groupclock"}})) +CREATE_COUNTER(context, lock_timeout, cpu, ({{"in", "timeout"}, {"action", "lock"}, {"lock", "timeout"}})) +CREATE_COUNTER(context, lock_peermanager, cpu, ({{"in", "peer"}, {"action", "lock"}, {"lock", "peermanager"}})) +CREATE_COUNTER(context, lock_peer, cpu, ({{"in", "peer"}, {"action", "lock"}, {"lock", "peer"}})) +CREATE_COUNTER(context, lock_clientmanager, cpu, ({{"in", "client"}, {"action", "lock"}, {"lock", "clientmanager"}})) +CREATE_COUNTER(context, lock_client, cpu, ({{"in", "client"}, {"action", "lock"}, {"lock", "client"}})) +CREATE_COUNTER(context, lock_test, cpu, ({{"in", "test"}})) +CREATE_COUNTER(context, lock_socket_read, socket, ({{"in", "socket"}, {"action", "lock"}, {"lock", "read"}})) +CREATE_COUNTER(context, lock_socket_write, socket, ({{"in", "socket"}, {"action", "lock"}, {"lock", "write"}})) diff --git a/enclave/metrics/gauges.h b/enclave/metrics/gauges.h new file mode 100644 index 0000000..e1cfb6e --- /dev/null +++ b/enclave/metrics/gauges.h @@ -0,0 +1,51 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// This file contains 
all gauge metrics used within SVR2.
+//
+// They're created with the macro CREATE_GAUGE, which takes arguments:
+//   * ns - namespace of the gauge (generally, module name)
+//   * varname - name of the variable used to reference this gauge, must be
+//       unique within the namespace (ns).  Also the exported name.
+//
+// Once these gauges are created here, they're used with the incantation:
+//   GAUGE(ns, varname)->GaugeFunction();
+// IE:
+//   GAUGE(raft, current_term)->Set(12);
+//
+// Gauges are only exported after their first Set call, to avoid sending up
+// spurious invalid values to metrics.  If Clear is called, they will no longer
+// be exported.
+
+CREATE_GAUGE(raft, role)
+CREATE_GAUGE(raft, is_voting)
+CREATE_GAUGE(raft, current_term)
+CREATE_GAUGE(raft, commit_index)
+CREATE_GAUGE(raft, promise_index)
+CREATE_GAUGE(raft, log_oldest_stored_log_index)
+CREATE_GAUGE(raft, log_last_log_term)
+CREATE_GAUGE(raft, log_last_log_index)
+CREATE_GAUGE(raft, log_size)
+CREATE_GAUGE(raft, log_total_size)
+CREATE_GAUGE(raft, log_entries)
+
+CREATE_GAUGE(core, raft_state)
+CREATE_GAUGE(core, last_index_applied_to_db)
+CREATE_GAUGE(core, current_local_time)
+CREATE_GAUGE(core, current_groupclock_time)
+
+
+CREATE_GAUGE(peers, peers)
+
+CREATE_GAUGE(client, clients)
+
+CREATE_GAUGE(db, rows)
+
+CREATE_GAUGE(timeout, timeouts)
+
+CREATE_GAUGE(test, test1)
+CREATE_GAUGE(test, test2)
+
+CREATE_GAUGE(env, total_heap_size)
+CREATE_GAUGE(env, allocated_heap_size)
+CREATE_GAUGE(env, peak_heap_size)
diff --git a/enclave/metrics/metrics.cc b/enclave/metrics/metrics.cc
new file mode 100644
index 0000000..745bb2d
--- /dev/null
+++ b/enclave/metrics/metrics.cc
@@ -0,0 +1,100 @@
+// Copyright 2023 Signal Messenger, LLC
+// SPDX-License-Identifier: AGPL-3.0-only
+
+#include "metrics/metrics.h"
+
+namespace svr2::metrics {
+
+namespace {
+static std::atomic recorded_errors[error::Error_ARRAYSIZE] = {0};
+}  // namespace
+
+// Exports every metric as one MetricsPB: non-zero error counts first (as
+// "errors" counters tagged with the error name), then all counters, then
+// every gauge that currently holds a value.
+MetricsPB AllAsPB() {
+  MetricsPB out;
+  for (int i = 0; i < error::Error_ARRAYSIZE; i++) {
+    if (error::Error_IsValid(i)) {
+      uint64_t v = recorded_errors[i].load();
+      if (v > 0) {
+        U64PB* counter = out.add_counters();
+        counter->set_name("errors");
+        (*counter->mutable_tags())["error"] = error::Error_Name(i);
+        counter->set_v(v);
+      }
+    }
+  }
+  for (int i = 0; i < COUNTERS_ARRAY_SIZE; i++) {
+    internal::counters[i].AddToMetrics(&out);
+  }
+  for (int i = 0; i < GAUGES_ARRAY_SIZE; i++) {
+    internal::gauges[i].AddToMetrics(&out);
+  }
+  return out;
+}
+
+// Resets error counts, counters, AND gauges.  Gauges must be cleared too:
+// otherwise a gauge set by one test is still exported in the next one.
+void ClearAllForTest() {
+  for (int i = 0; i < error::Error_ARRAYSIZE; i++) {
+    recorded_errors[i].store(0);
+  }
+  for (int i = 0; i < COUNTERS_ARRAY_SIZE; i++) {
+    internal::counters[i].Clear();
+  }
+  for (int i = 0; i < GAUGES_ARRAY_SIZE; i++) {
+    internal::gauges[i].Clear();
+  }
+}
+
+Counter::Counter(const std::string& name, std::map&& tags)
+    : name_(name), tags_(tags) {}
+
+void Counter::IncrementBy(uint64_t v) {
+  v_.fetch_add(v);
+}
+
+// Counters are always exported, even when their value is zero.
+void Counter::AddToMetrics(MetricsPB* pb) {
+  auto c = pb->add_counters();
+  c->set_name(name_);
+  c->set_v(v_.load());
+  for (auto iter = tags_.cbegin(); iter != tags_.cend(); ++iter) {
+    (*c->mutable_tags())[iter->first] = iter->second;
+  }
+}
+
+void Counter::Clear() {
+  v_.store(0);
+}
+
+Gauge::Gauge(const std::string& name)
+    : v_(UINT64_MAX), name_(name) {}
+
+void Gauge::Set(uint64_t v) {
+  v_.store(v);
+}
+
+void Gauge::AddToMetrics(MetricsPB* pb) {
+  uint64_t v = v_.load();
+  // UINT64_MAX marks a gauge that was never Set (or was Cleared); skip it.
+  if (v == UINT64_MAX) { return; }
+  auto c = pb->add_gauges();
+  c->set_name(name_);
+  c->set_v(v);
+}
+
+void Gauge::Clear() {
+  v_.store(UINT64_MAX);
+}
+
+namespace internal {
+// Bumps the count for `e` and returns `e` unchanged, so the call can wrap a
+// return expression (see COUNTED_ERROR).
+error::Error RecordError(error::Error e) {
+  recorded_errors[e].fetch_add(1);
+  return e;
+}
+
+Counter counters[COUNTERS_ARRAY_SIZE] = {
+#define CREATE_COUNTER(ns, varname, name, tags) Counter(#ns "." #name, std::maptags),
+#include "counters.h"
+#undef CREATE_COUNTER
+};
+Gauge gauges[GAUGES_ARRAY_SIZE] = {
+#define CREATE_GAUGE(ns, name) Gauge(#ns "." #name),
+#include "gauges.h"
+#undef CREATE_GAUGE
+};
+}  // namespace internal
+
+}  // namespace svr2::metrics
diff --git a/enclave/metrics/metrics.h b/enclave/metrics/metrics.h
new file mode 100644
index 0000000..a11affd
--- /dev/null
+++ b/enclave/metrics/metrics.h
@@ -0,0 +1,92 @@
+// Copyright 2023 Signal Messenger, LLC
+// SPDX-License-Identifier: AGPL-3.0-only
+
+#ifndef __SVR2_METRICS_METRICS_H__
+#define __SVR2_METRICS_METRICS_H__
+
+#include
+#include
+
+#include "proto/metrics.pb.h"
+#include "proto/error.pb.h"
+
+namespace svr2::metrics {
+
+// Export all global metrics as a single protobuf.
+MetricsPB AllAsPB();
+
+// Return all global metrics (errors, counters, gauges) to an initial state.  For testing only.
+void ClearAllForTest();
+
+// A counter provides a simple, atomic counter object that monotonically increases.
+// We do not protect against overflows, but given that this is a 64-bit value, they
+// would be pretty impressive.
+class Counter {
+ public:
+  Counter(const std::string& name, std::map&& tags);
+  void IncrementBy(uint64_t v);
+  inline void Increment() { IncrementBy(1); }
+ private:
+  friend MetricsPB AllAsPB();
+  friend void ClearAllForTest();
+  void AddToMetrics(MetricsPB* pb);
+  void Clear();
+  std::atomic v_;
+  const std::string name_;
+  const std::map tags_;
+};
+
+// A gauge provides a simple, atomic gauge object that can be set to arbitrary
+// values.  We save UINT64_MAX as a special invalid value.
+class Gauge {
+ public:
+  Gauge(const std::string& name);
+  void Set(uint64_t v);
+  void Clear();
+ private:
+  friend MetricsPB AllAsPB();
+  friend void ClearAllForTest();
+  void AddToMetrics(MetricsPB* pb);
+  std::atomic v_;
+  const std::string name_;
+};
+
+// We use the somewhat tricky counters.h/gauges.h files to generate a set of metrics
+// that are both accessible to the rest of the code and iterable by this code.
+// In short, we use a CREATE_COUNTER/CREATE_GAUGE macros, which we define/include/undef, +// both here and in metrics.cc, to generate the header and source parts of the metrics. +enum Counters { +#define CREATE_COUNTER(ns, varname, name, tags) CTR__##ns##__##varname, +#include "counters.h" +#undef CREATE_COUNTER + COUNTERS_ARRAY_SIZE, +}; +enum Gauges { +#define CREATE_GAUGE(ns, name) GAG__##ns##__##name, +#include "gauges.h" +#undef CREATE_GAUGE + GAUGES_ARRAY_SIZE, +}; + +namespace internal { +error::Error RecordError(error::Error); +extern Counter counters[COUNTERS_ARRAY_SIZE]; +extern Gauge gauges[GAUGES_ARRAY_SIZE]; +} // namespace internal + +} // namespace svr2::metrics + +// COUNTER(ns, name) returns a pointer to a metrics::Counter based on the +// counter namespace/name as created in counters.h. +#define COUNTER(ns, name) (&::svr2::metrics::internal::counters[::svr2::metrics::CTR__##ns##__##name]) + +// GAUGE(ns, name) returns a pointer to a metrics::Gauge based on the +// gauge namespace/name as created in gauges.h. +#define GAUGE(ns, name) (&::svr2::metrics::internal::gauges[::svr2::metrics::GAG__##ns##__##name]) + +// COUNTED_ERROR counts an error within metrics, returning that same error. 
+// It's generally used like: +// return COUNTED_ERROR(Foo_Bar); +#define COUNTED_ERROR(x) ::svr2::metrics::internal::RecordError(error::x) + +#endif // __SVR2_METRICS_METRICS_H__ diff --git a/enclave/metrics/tests/metrics.cc b/enclave/metrics/tests/metrics.cc new file mode 100644 index 0000000..9c24278 --- /dev/null +++ b/enclave/metrics/tests/metrics.cc @@ -0,0 +1,122 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite + +#include +#include "proto/error.pb.h" +#include "metrics/metrics.h" + +namespace svr2::metrics { + +class MetricsTest : public ::testing::Test { + protected: + void SetUp() { + ClearAllForTest(); + } + + const int FindCounter(const MetricsPB& pb, const std::string& name, const std::map& tags) { + for (int i = 0; i < pb.counters_size(); i++) { + auto c = pb.counters(i); + if (name != c.name() || tags.size() != c.tags().size()) { + continue; + } + bool tags_equal = true; + for (auto iter = tags.cbegin(); iter != tags.cend() && tags_equal; ++iter) { + if (c.tags().count(iter->first) == 0 || + c.tags().at(iter->first) != iter->second) { + tags_equal = false; + break; + } + } + if (!tags_equal) { continue; } + return i; + } + return -1; + } + const int FindGauge(const MetricsPB& pb, const std::string& name) { + for (int i = 0; i < pb.gauges_size(); i++) { + auto c = pb.gauges(i); + if (name == c.name()) { return i; } + } + return -1; + } +}; + +error::Error ReturnsGeneralUnimplemented() { + return COUNTED_ERROR(General_Unimplemented); +} + +error::Error ReturnsCoreReInit() { + return COUNTED_ERROR(Core_ReInit); +} + +TEST_F(MetricsTest, CountsReturnedErrors) { + for (int i = 0; i < 3; i++) { + ReturnsGeneralUnimplemented(); + } + MetricsPB got = AllAsPB(); + ASSERT_EQ(got.counters_size(), 1 + COUNTERS_ARRAY_SIZE); + auto c = got.counters(0); + ASSERT_EQ(c.v(), 3); + ASSERT_EQ(c.tags().at("error"), "General_Unimplemented"); + for (int i = 
0; i < 5; i++) { + ReturnsCoreReInit(); + } + got = AllAsPB(); + ASSERT_EQ(got.counters_size(), 2 + COUNTERS_ARRAY_SIZE); + c = got.counters(0); + ASSERT_EQ(c.v(), 3); + ASSERT_EQ(c.tags().at("error"), "General_Unimplemented"); + c = got.counters(1); + ASSERT_EQ(c.v(), 5); + ASSERT_EQ(c.tags().at("error"), "Core_ReInit"); +} + +TEST_F(MetricsTest, Counters) { + COUNTER(core, peer_msgs_received)->Increment(); + COUNTER(core, peer_msgs_received)->Increment(); + COUNTER(core, peer_msgs_received)->Increment(); + MetricsPB got = AllAsPB(); + int i = FindCounter(got, "core.msgs_received", {{"type", "peer_message"}}); + ASSERT_GE(i, 0); + auto c = got.counters(i); + ASSERT_EQ(c.name(), "core.msgs_received"); + ASSERT_EQ(c.tags().size(), 1); + ASSERT_EQ(c.tags().at("type"), "peer_message"); + ASSERT_EQ(c.v(), 3); +} + +TEST_F(MetricsTest, Gauges) { + MetricsPB got = AllAsPB(); + ASSERT_EQ(got.gauges_size(), 0); + GAUGE(test, test1)->Set(123); + got = AllAsPB(); + ASSERT_EQ(got.gauges_size(), 1); + EXPECT_EQ(got.gauges(0).name(), "test.test1"); + EXPECT_EQ(got.gauges(0).v(), 123); + GAUGE(test, test2)->Set(234); + GAUGE(test, test1)->Set(345); + got = AllAsPB(); + ASSERT_EQ(got.gauges_size(), 2); + int t1 = FindGauge(got, "test.test1"); + int t2 = FindGauge(got, "test.test2"); + ASSERT_GE(t1, 0); + ASSERT_GE(t2, 0); + auto g1 = got.gauges(t1); + auto g2 = got.gauges(t2); + EXPECT_EQ(g1.name(), "test.test1"); + EXPECT_EQ(g1.v(), 345); + EXPECT_EQ(g2.name(), "test.test2"); + EXPECT_EQ(g2.v(), 234); + GAUGE(test, test1)->Clear(); + got = AllAsPB(); + ASSERT_EQ(got.gauges_size(), 1); + EXPECT_EQ(got.gauges(0).name(), "test.test2"); + EXPECT_EQ(got.gauges(0).v(), 234); +} + +} // namespace svr2::metrics diff --git a/enclave/nitromain/nitromain.cc b/enclave/nitromain/nitromain.cc new file mode 100644 index 0000000..2424446 --- /dev/null +++ b/enclave/nitromain/nitromain.cc @@ -0,0 +1,144 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + 
+#include +#include +#include +#include +#include + +#include "env/env.h" +#include "core/core.h" +#include "context/context.h" +#include "proto/enclaveconfig.pb.h" +#include "util/log.h" +#include "util/bytes.h" +#include "proto/nitro.pb.h" +#include "socketwrap/socket.h" +#include "env/nsm/nsm.h" +#include "queue/queue.h" + +namespace svr2 { + +#define RETURN_ERRNO_ERROR_IF(x, err) do { \ + if ((x)) { \ + int e = errno; \ + LOG(ERROR) << "(" << #x << ") evaluated to false, errno(" << e << "): " << strerror(e); \ + return COUNTED_ERROR(err); \ + } \ +} while (0) + +// To simplify our server, this function creates the appropriate +// AF_VSOCK, binds it, listens, accepts, then returns the accepted +// file descriptor, closing the listener. We know that if this +// socket dies, we stop serving, so there's no need to create an +// accept loop. +error::Error AcceptSocket(int* afd) { + int fd; + RETURN_ERRNO_ERROR_IF( + 0 >= (fd = socket(AF_VSOCK, SOCK_STREAM, 0)), + Nitro_SocketCreation); + + struct sockaddr_vm my_addr; + memset(&my_addr, 0, sizeof(my_addr)); + my_addr.svm_family = AF_VSOCK; + my_addr.svm_port = VMADDR_PORT_ANY; + my_addr.svm_cid = VMADDR_CID_ANY; + RETURN_ERRNO_ERROR_IF( + 0 != bind(fd, (struct sockaddr *) &my_addr, sizeof(my_addr)), + Nitro_SocketBind); + RETURN_ERRNO_ERROR_IF( + 0 != listen(fd, 2), + Nitro_SocketListen); + + *afd = 0; + while (*afd <= 0) { + struct sockaddr_vm remote_addr; + socklen_t remote_len = sizeof(remote_addr); + *afd = accept4(fd, reinterpret_cast(&remote_addr), &remote_len, SOCK_CLOEXEC); + RETURN_ERRNO_ERROR_IF( + *afd <= 0 && errno != EINTR && errno != ECONNABORTED, + Nitro_SocketAccept); + } + shutdown(fd, SHUT_RDWR); + close(fd); + return error::OK; +} + +error::Error RunServerThread(core::Core* core, socketwrap::Socket* sock) { + while (true) { + context::Context ctx; + auto in = ctx.Protobuf(); + RETURN_IF_ERROR(sock->ReadPB(&ctx, in)); + if (in->inner_case() != nitro::InboundMessage::kMsg) { + return 
// Read an init message from a socket and use it to create a new core object.
//
// Returns the created core together with error::OK once the core's ID has
// been written back over `sock`.  If the first message cannot be read, or is
// not an init message, the core pointer is nullptr.  If Core::Create or the
// reply write fails, whatever pointer Core::Create produced is returned
// alongside that error.
std::pair, error::Error> InitCore(socketwrap::Socket* sock) {
  context::Context ctx;
  // Message objects are obtained from the context.
  // NOTE(review): presumably ctx owns them for its lifetime - confirm.
  auto init = ctx.Protobuf();
  if (error::Error err = sock->ReadPB(&ctx, init); err != error::OK) {
    return std::make_pair(nullptr, err);
  }
  // The very first inbound message must be the init variant.
  if (init->inner_case() != nitro::InboundMessage::kInit) {
    return std::make_pair(nullptr, COUNTED_ERROR(Nitro_InboundNotInit));
  }
  auto [core_ptr, err] = core::Core::Create(
      &ctx,
      init->init());
  if (err == error::OK) {
    // Reply with the new core's ID so the other side learns who we are.
    auto out = ctx.Protobuf();
    core_ptr->ID().ToString(out->mutable_init()->mutable_peer_id());
    err = sock->WritePB(&ctx, *out);
  }
  return std::make_pair(std::move(core_ptr), err);
}
+error::Error RunServer() { + int fd; + RETURN_IF_ERROR(AcceptSocket(&fd)); + socketwrap::Socket sock(fd); + auto sockp = &sock; + std::vector threads; + threads.emplace_back([sockp]{ + LOG(FATAL) << env::nsm::SendNsmMessages(sockp); + }); + auto [c, err] = InitCore(&sock); + RETURN_IF_ERROR(err); + auto cp = c.get(); + for (size_t i = 0; i < 32 /* chosen by random dice roll */; i++) { + threads.emplace_back([cp, sockp]{ + LOG(FATAL) << RunServerThread(cp, sockp); + }); + } + for (size_t i = 0; i < threads.size(); i++) { + threads[i].join(); + } + return error::OK; // unreachable +} + +error::Error Run() { + env::Init(); + return RunServer(); +} + +} // namespace svr2 + +int main(int argc, char** argv) { + LOG(FATAL) << svr2::Run(); + return -1; +} diff --git a/enclave/noise-c b/enclave/noise-c new file mode 160000 index 0000000..3541938 --- /dev/null +++ b/enclave/noise-c @@ -0,0 +1 @@ +Subproject commit 354193847d04475e474a89dbb11b6434e1d9cbca diff --git a/enclave/noise/noise.cc b/enclave/noise/noise.cc new file mode 100644 index 0000000..0290f76 --- /dev/null +++ b/enclave/noise/noise.cc @@ -0,0 +1,73 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "noise/noise.h" +#include +#include "util/log.h" +#include "metrics/metrics.h" + +namespace svr2::noise { + +static size_t max_message_size = 65535; + +std::pair Encrypt(NoiseCipherState* cs, const std::string& plaintext) { + std::string ciphertext; + size_t mac_size = noise_cipherstate_get_mac_length(cs); + size_t max_encrypt_size = max_message_size - mac_size; + size_t orig_size = plaintext.size(); + // We need to fit our plaintext into some number of Noise output packets. + // Each of those packets cannot be larger than max_message_size, and must + // contain some amount of ciphertext along with Noise's added MAC. + // Thus, we have to add some amount of size equivilent to a multiple of + // mac_size to the size of plaintext to get the final size of *ciphertext. 
+ // Examples of input sizes and output sizes, around the max_message_size + // boundary, are: + // size == 1 : add mac_size * 1 -> [cleartext(1B)][mac] + // size == max_message_size - mac_size : add mac_size * 1 -> [cleartext(max_msg_sizeB)][mac] + // size == max_message_size - mac_size + 1 : add mac_size * 2 -> [cleartext(max_msg_sizeB)][mac1][cleartext(1B)][mac2] + size_t num_macs = orig_size / max_encrypt_size + 1; + if (orig_size % max_encrypt_size == 0 && num_macs > 1) num_macs--; + size_t macs_size = mac_size * num_macs; + size_t final_size = orig_size + macs_size; + ciphertext.resize(final_size, 0); + size_t plaintext_start = 0; + for (size_t start = 0; start < final_size; start += max_message_size) { + size_t plaintext_size = std::min(max_encrypt_size, plaintext.size() - plaintext_start); + memcpy(StrU8Ptr(&ciphertext) + start, StrU8Ptr(plaintext) + plaintext_start, plaintext_size); + plaintext_start += plaintext_size; + NoiseBuffer buf; + noise_buffer_set_inout(buf, StrU8Ptr(&ciphertext) + start, plaintext_size, plaintext_size + mac_size); + if (NOISE_ERROR_NONE != noise_cipherstate_encrypt(cs, &buf)) { + return std::make_pair("", COUNTED_ERROR(Peers_Encrypt)); + } + } + return std::make_pair(ciphertext, error::OK); +} + +std::pair Decrypt(NoiseCipherState* cs, const std::string& ciphertext) { + std::string plaintext(ciphertext.size(), 0); + size_t plaintext_start = 0; + // Data comes in as [ciphertext][mac][ciphertext][mac]. 
+ for (size_t start = 0; start < ciphertext.size(); start += max_message_size) { + size_t size = std::min(max_message_size, ciphertext.size() - start); + memcpy(StrU8Ptr(&plaintext) + plaintext_start, StrU8Ptr(ciphertext) + start, size); + NoiseBuffer buf; + noise_buffer_set_inout(buf, StrU8Ptr(&plaintext) + plaintext_start, size, size); + if (NOISE_ERROR_NONE != noise_cipherstate_decrypt(cs, &buf)) { + return std::make_pair("", COUNTED_ERROR(Peers_Decrypt)); + } + plaintext_start += buf.size; + } + plaintext.resize(plaintext_start, 0); + return std::make_pair(plaintext, error::OK); +} + +DHState CloneDHState(const DHState& s) { + NoiseDHState* sp = nullptr; + auto dh_id = noise_dhstate_get_dh_id(s.get()); + CHECK(NOISE_ERROR_NONE == noise_dhstate_new_by_id(&sp, dh_id)); + CHECK(NOISE_ERROR_NONE == noise_dhstate_copy(sp, s.get())); + return WrapDHState(sp); +} + +} // namespace svr2::noise diff --git a/enclave/noise/noise.h b/enclave/noise/noise.h new file mode 100644 index 0000000..c4a43c4 --- /dev/null +++ b/enclave/noise/noise.h @@ -0,0 +1,76 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_NOISE_NOISE_H__ +#define __SVR2_NOISE_NOISE_H__ + +#include +#include +#include +#include +#include +#include +#include "proto/error.pb.h" + +// This module provides simple RAII wrappers around noise-c pointers. +// The pointers are exposed publicly as .state, allowing use of noise_* functions +// directly on them, but with the guarantee that when the *State objects fall +// out of scope, the correct noise_*_free function will be called on them. 
+ +#include "util/macros.h" + +namespace svr2::noise { + +const size_t HANDSHAKE_INIT_SIZE = 64; + +inline uint8_t* StrU8Ptr(std::string* s) { + return reinterpret_cast(s->data()); +} +inline const uint8_t* StrU8Ptr(const std::string& s) { + return reinterpret_cast(s.data()); +} + +inline NoiseBuffer BufferOutputFromString(std::string* s) { + NoiseBuffer b; + noise_buffer_set_output(b, StrU8Ptr(s), s->size()); + return b; +} + +inline NoiseBuffer BufferInputFromString(std::string* s) { + NoiseBuffer b; + noise_buffer_set_input(b, StrU8Ptr(s), s->size()); + return b; +} + +inline NoiseBuffer BufferInoutFromString(std::string* s, size_t substr) { + CHECK(substr <= s->size()); + NoiseBuffer b; + noise_buffer_set_inout(b, StrU8Ptr(s), substr, s->size()); + return b; +} + +typedef std::unique_ptr HandshakeState; +inline HandshakeState WrapHandshakeState(NoiseHandshakeState* s) { + return HandshakeState(s, noise_handshakestate_free); +} + +typedef std::unique_ptr DHState; +inline DHState WrapDHState(NoiseDHState* s) { + return DHState(s, noise_dhstate_free); +} + +DHState CloneDHState(const DHState& s); + +typedef std::unique_ptr CipherState; +inline CipherState WrapCipherState(NoiseCipherState* s) { + return CipherState(s, noise_cipherstate_free); +} + +// Encrypt the given string. +std::pair Encrypt(NoiseCipherState* cs, const std::string& plaintext); +// Decrypt the given string. 
// Fixture for the noise::Encrypt / noise::Decrypt round-trip tests.
class CipherStateTest : public ::testing::Test {
 protected:
  // Runs env::Init() once for the whole test case.
  static void SetUpTestCase() {
    env::Init();
  }

  // Encrypts `plaintext` with one cipher state and decrypts the result with a
  // second one created from the same key and the same cipher `type`.  Asserts
  // that both steps succeed and that the round trip reproduces `plaintext`,
  // then hands the ciphertext back through `ciphertext_out`.
  void EncryptDecrypt(const std::string& plaintext, std::string* ciphertext_out, int type) {
    std::array key = {1};
    NoiseCipherState* s1n;
    NoiseCipherState* s2n;
    // NOTE(review): a failing ASSERT below returns before the raw pointers are
    // wrapped, so the cipher states would leak on that path (test-only).
    ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_new_by_id(&s1n, type));
    ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_init_key(s1n, key.data(), key.size()));
    ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_new_by_id(&s2n, type));
    ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_init_key(s2n, key.data(), key.size()));
    // The wrappers exist only to free the states when this function returns.
    noise::CipherState s1 = noise::WrapCipherState(s1n);
    noise::CipherState s2 = noise::WrapCipherState(s2n);
    auto [ciphertext, enc_err] = noise::Encrypt(s1n, plaintext);
    ASSERT_EQ(error::OK, enc_err);
    auto [computed_plaintext, dec_err] = noise::Decrypt(s2n, ciphertext);
    ASSERT_EQ(error::OK, dec_err);
    ASSERT_EQ(plaintext, computed_plaintext);
    ciphertext_out->swap(ciphertext);
  }
};
ciphertext.size()); + EncryptDecrypt("a", &ciphertext, NOISE_CIPHER_CHACHAPOLY); + ASSERT_EQ(17, ciphertext.size()); + + EncryptDecrypt("this is a test of the emergency broadcast system", &ciphertext, NOISE_CIPHER_CHACHAPOLY); + + std::string s; + + s.resize(65535-16, 'a'); + EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY); + ASSERT_EQ(ciphertext.size(), 65535); + + s.resize(65535-15, 'a'); + EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY); + ASSERT_EQ(ciphertext.size(), 65535-15+32); + + s.resize((65535-16)*10, 'a'); + EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY); + ASSERT_EQ(ciphertext.size(), 65535*10); +} + +TEST_F(CipherStateTest, BenchmarkChaChaPoly) { + std::string plaintext; + std::string ciphertext; + plaintext.resize(1 << 20, 'a'); + auto start = util::asm_rdtsc(); + int times = 100; + for (int i = 0; i < times; i++) { + EncryptDecrypt(plaintext, &ciphertext, NOISE_CIPHER_CHACHAPOLY); + } + LOG(INFO) << "took " << ((util::asm_rdtsc() - start) * 1.0 / (times * plaintext.size())) << " cycles/byte"; +} + +TEST_F(CipherStateTest, BenchmarkAesGcm) { + std::string plaintext; + std::string ciphertext; + plaintext.resize(1 << 20, 'a'); + auto start = util::asm_rdtsc(); + int times = 100; + for (int i = 0; i < times; i++) { + EncryptDecrypt(plaintext, &ciphertext, NOISE_CIPHER_AESGCM); + } + LOG(INFO) << "took " << ((util::asm_rdtsc() - start) * 1.0 / (times * plaintext.size())) << " cycles/byte"; +} + +} // namespace svr2::noise diff --git a/enclave/noisewrap/tests/noisewrap.cc b/enclave/noisewrap/tests/noisewrap.cc new file mode 100644 index 0000000..270d681 --- /dev/null +++ b/enclave/noisewrap/tests/noisewrap.cc @@ -0,0 +1,33 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP noise-c +//TESTDEP noisewrap +//TESTDEP util +//TESTDEP env +//TESTDEP env/test +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include +#include "env/env.h" +#include 
"util/log.h" +#include +#include +#include "util/hex.h" + +namespace svr2 { + +TEST(NoiseWrap, RandomnessIsWrappedDeterministically) { + svr2::env::Init(); + std::array out; + ASSERT_EQ(NOISE_ERROR_NONE, noise_randstate_generate_simple(out.data(), out.size())); + LOG(INFO) << "RAND: " << util::ToHex(out); + uint8_t expect[8] = {0x4f, 0x6f, 0xa8, 0x48, 0x32, 0xaa, 0x7d, 0x32}; + ASSERT_EQ(0, memcmp(out.data(), expect, 8)); +} + +} // namespace svr2 diff --git a/enclave/noisewrap/wrap.cc b/enclave/noisewrap/wrap.cc new file mode 100644 index 0000000..2b56fa0 --- /dev/null +++ b/enclave/noisewrap/wrap.cc @@ -0,0 +1,15 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "env/env.h" +#include "proto/error.pb.h" +#include "util/macros.h" + +extern "C" { + +// Wrap Noise's call to get randomness so it uses our enclave's random generator. +void __wrap_noise_rand_bytes(void* bytes, size_t size) { + CHECK(::svr2::error::OK == ::svr2::env::environment->RandomBytes(bytes, size)); +} + +} // extern "C" diff --git a/enclave/peerid/peerid.cc b/enclave/peerid/peerid.cc new file mode 100644 index 0000000..321cdd2 --- /dev/null +++ b/enclave/peerid/peerid.cc @@ -0,0 +1,49 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include +#include + +#include "peerid/peerid.h" +#include "sip/halfsiphash.h" +#include "util/log.h" +#include "metrics/metrics.h" +#include "util/hex.h" + +namespace svr2::peerid { + +static std::array zero_id = {0}; + +size_t PeerIDHasher::operator()(const PeerID& id) const { + return Hash(id.id_.data(), id.id_.size()); +} + +PeerID::PeerID(const uint8_t array[32]) { + std::copy(array, array+32, id_.begin()); +} +PeerID::PeerID() : id_({0}) {} +error::Error PeerID::FromString(const std::string& s) { + if (s.size() != id_.size()) { + return COUNTED_ERROR(Peers_InvalidID); + } + std::copy(s.begin(), s.end(), id_.begin()); + return error::OK; +} +bool PeerID::Valid() const { + // 
https://cr.yp.to/ecdh.html#validate + return id_ != zero_id; +} +void PeerID::ToString(std::string* s) const { + s->resize(32, 0); + std::copy(id_.begin(), id_.end(), s->begin()); +} +std::string PeerID::DebugString() const { + return util::PrefixToHex(id_, 4); +} + +std::ostream& operator<<(std::ostream& os, const PeerID& peer_id) { + os << peer_id.DebugString(); + return os; +} + +} // namespace svr2::peerid diff --git a/enclave/peerid/peerid.h b/enclave/peerid/peerid.h new file mode 100644 index 0000000..4741d5f --- /dev/null +++ b/enclave/peerid/peerid.h @@ -0,0 +1,54 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_PEERID_PEERID_H__ +#define __SVR2_PEERID_PEERID_H__ + +#include +#include +#include + +#include "context/context.h" +#include "proto/error.pb.h" +#include "sip/hasher.h" + +namespace svr2::peerid { + +class PeerID; + +class PeerIDHasher : public sip::Hasher { + public: + size_t operator()(const PeerID& id) const; +}; + +class PeerID { + public: + PeerID(PeerID&& moved) = default; + PeerID(const PeerID& copied) = default; + PeerID& operator=(const PeerID& other) = default; + PeerID(); // all zeros, invalid + PeerID(const uint8_t array[32]); + error::Error FromString(const std::string& s); + void ToString(std::string* s) const; + const std::array& Get() const { return id_; } + bool Valid() const; + bool operator==(const PeerID& other) const { return id_ == other.id_; } + bool operator!=(const PeerID& other) const { return id_ != other.id_; } + bool operator<(const PeerID& other) const { return id_ < other.id_; } + std::string DebugString() const; + std::string AsString() const { std::string out; ToString(&out); return out; } + + // Prints DebugString() to an ostream. 
Overload is acceptable because + // PeerID represents a value and DebugString() does not expose any implementation + // details of the object (https://google.github.io/styleguide/cppguide.html#Streams) + friend std::ostream& operator<<(std::ostream& os, const PeerID& peer_id); + + + private: + std::array id_; + friend class PeerIDHasher; +}; + +} // namespace svr2::peerid + +#endif // __SVR2_PEERID_PEERID_H__ diff --git a/enclave/peerid/tests/peerid.cc b/enclave/peerid/tests/peerid.cc new file mode 100644 index 0000000..1a3e0db --- /dev/null +++ b/enclave/peerid/tests/peerid.cc @@ -0,0 +1,107 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP sip +//TESTDEP sender +//TESTDEP env +//TESTDEP env/test +//TESTDEP util +//TESTDEP context +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include "peers/peers.h" +#include "env/env.h" +#include "util/log.h" +#include "proto/e2e.pb.h" +#include +#include +#include + +namespace svr2::peerid { + +class PeerIDTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } +}; + +TEST_F(PeerIDTest, Valid) { + PeerID id; + ASSERT_FALSE(id.Valid()); + std::string more_valid = "12345678901234567890123456789012"; + ASSERT_EQ(error::OK, id.FromString(more_valid)); + ASSERT_TRUE(id.Valid()); +} + +TEST_F(PeerIDTest, FromString) { + PeerID id; + std::string valid = "12345678901234567890123456789012"; + ASSERT_EQ(error::OK, id.FromString(valid)); + std::array expected = { + '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', + '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', + }; + ASSERT_EQ(expected, id.Get()); + ASSERT_NE(error::OK, id.FromString("badstring")); + // We can set the string to invalid (all zeros), and FromString will still succeed. 
+ ASSERT_EQ(error::OK, id.FromString(std::string("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 32))); +} + +TEST_F(PeerIDTest, FromArray) { + uint8_t in[32] = {1}; + PeerID id(in); + std::array expected = { + 1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + }; + ASSERT_EQ(expected, id.Get()); +} + +TEST_F(PeerIDTest, Equality) { + PeerID id1, id2; + std::string valid = "12345678901234567890123456789012"; + std::string valid2 = "00045678901234567890123456789012"; + ASSERT_EQ(error::OK, id1.FromString(valid)); + ASSERT_EQ(error::OK, id2.FromString(valid)); + ASSERT_TRUE(id1 == id2); + ASSERT_EQ(error::OK, id2.FromString(valid2)); + ASSERT_FALSE(id1 == id2); + ASSERT_EQ(error::OK, id1.FromString(valid2)); + ASSERT_TRUE(id1 == id2); +} + +TEST_F(PeerIDTest, DebugString) { + PeerID id; + ASSERT_EQ(id.DebugString(), "00000000"); + uint8_t in[32] = {1, 2, 3}; + id = PeerID(in); + ASSERT_EQ(id.DebugString(), "01020300"); +} + +TEST_F(PeerIDTest, Copy) { + PeerID id1; + std::string valid = "12345678901234567890123456789012"; + ASSERT_EQ(error::OK, id1.FromString(valid)); + PeerID id2 = id1; + ASSERT_TRUE(id1 == id2); +} + +TEST_F(PeerIDTest, Mapping) { + std::unordered_map map; + for (uint8_t i = 1; i <= 10; i++) { + uint8_t in[32] = {i}; + map[PeerID(in)] = i; + } + for (uint8_t i = 1; i <= 10; i++) { + uint8_t in[32] = {i}; + ASSERT_EQ(map[PeerID(in)], i); + } +} + +} // namespace svr2::peerid diff --git a/enclave/peers/peers.cc b/enclave/peers/peers.cc new file mode 100644 index 0000000..1e10003 --- /dev/null +++ b/enclave/peers/peers.cc @@ -0,0 +1,691 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include +#include +#include + +#include +#include + +#include "peers/peers.h" +#include "util/macros.h" +#include "env/env.h" +#include "sip/halfsiphash.h" +#include "util/endian.h" +#include "util/log.h" +#include "util/constant.h" +#include "sender/sender.h" +#include "metrics/metrics.h" + 
+// There's some mildly complicated locking going on between Peer and PeerManager +// objects to maintain the necessary invariants for smooth operation. +// In a multi-threaded environment, we want encryption/decryption operations to +// be able to utilize multiple cores. However, peer connections are by their +// nature serial, as each connection uses a serialized set of noise state for +// {en,de}cryption purposes. We've striven to make this locking as simple as +// possible, and we've come up with this: +// +// * The peer manager's lock protects its map (peers_) of peers. Peer objects +// may be added to this list, but are never removed. This is because... +// * Each Peer object contains a lock of its own, which serializes communication +// across that peer's established connection. +// +// In short, the PeerManager lock is simply for lookup, while the Peer lock +// is used for all encryption/decryption/etc. associated with a peer. This +// means that each peer/peer connection is effectively single-threaded, but +// if multiple messages are received from multiple peers, the enclave can +// process their encryption in parallel. + +namespace svr2::peers { + +static NoiseProtocolId peer_to_peer_protocol = { + .prefix_id = NOISE_PREFIX_STANDARD, + .pattern_id = NOISE_PATTERN_KK, + .dh_id = NOISE_DH_CURVE25519, + // We use ChaChaPoly for client communication, because it's easier on clients + // and the vast majority of client interaction is dominated by the DH key exchange, + // rather than the actual stream cipher. However, for peer-to-peer communication, + // we establish connections infrequently and then use the stream cipher a LOT. + // This is especially true during initial replication, when the entire database + // state needs to be encrypted/decrypted. Since we use a libsodium backend, we + // have access to hardware-accelerated AES, so we use that. 
// Encrypts `msg` with this peer's transmit cipher state and hands the result
// to sender::Send, addressed to this peer's ID.
//
// Returns Peers_SendBeforeConnect if the peer is not in PEER_CONNECTED,
// Peers_EncryptSerialize if `msg` cannot be serialized, or the error from
// noise::Encrypt.  On an encryption error the peer is disconnected and an RST
// is sent before returning.
error::Error Peer::Send(
    context::Context* ctx,
    const e2e::EnclaveToEnclaveMessage& msg) {
  // NOTE(review): presumably holds mu_ for the rest of the function - confirm
  // against the ACQUIRE_LOCK definition.
  ACQUIRE_LOCK(mu_, ctx, lock_peer);
  MEASURE_CPU(ctx, cpu_peer_encrypt);
  if (InternalCurrentState() != PEER_CONNECTED) {
    return COUNTED_ERROR(Peers_SendBeforeConnect);
  }
  auto enclave_message = ctx->Protobuf();
  auto send = enclave_message->mutable_peer_message();
  std::string serialized;
  if (!msg.SerializeToString(&serialized)) {
    return COUNTED_ERROR(Peers_EncryptSerialize);
  }
  auto [ciphertext, err] = noise::Encrypt(tx_.get(), serialized);
  if (err != error::OK) {
    // An encryption error probably means bad noise state, which is unrecoverable.
    InternalDisconnect();
    SendRst(ctx, id_);
    return err;
  }
  // Swap the ciphertext into the outgoing message instead of copying it.
  send->mutable_data()->swap(ciphertext);
  id_.ToString(send->mutable_peer_id());
  sender::Send(*enclave_message);
  return error::OK;
}
+ InternalDisconnect(); + SendRst(ctx, id_); + return err; + } + auto e2e_message = ctx->Protobuf(); + if (!e2e_message->ParseFromString(plaintext)) { + return COUNTED_ERROR(Peers_DecryptParse); + } + if (e2e_message->inner_case() == e2e::EnclaveToEnclaveMessage::kAttestationUpdate) { + auto err = CheckNextAttestation(e2e_message->attestation_update()); + if (err != error::OK) { + LOG(WARNING) << "Peer " << id_ << " attestation update failure: " << err; + InternalDisconnect(); + SendRst(ctx, id_); + } + return err; + } + *decoded = e2e_message; + } return error::OK; + case PeerMessage::kRst: + LOG(INFO) << "Received RST from " << id_; + InternalDisconnect(); + return error::OK; + case PeerMessage::kSyn: + CHECK(nullptr == "PeerManager.RecvFromPeer should have called Accept, not Recv"); + default: + return COUNTED_ERROR(Peers_InvalidMsg); + } +} + +error::Error Peer::FinishConnection( + context::Context* ctx, + const std::string& synack, + e2e::EnclaveToEnclaveMessage** decoded) { + if (NOISE_ACTION_READ_MESSAGE != noise_handshakestate_get_action(handshake_.get())) { + return COUNTED_ERROR(Peers_HandshakeState); + } + + noise::HandshakeState local_handshake = noise::WrapHandshakeState(nullptr); + local_handshake.swap(handshake_); + + e2e::ConnectRequest* conn = ctx->Protobuf(); + if (!conn->ParseFromString(synack)) { + return COUNTED_ERROR(Peers_FinishParseHandshake); + } + auto remote_attestation = conn->attestation(); + auto ts = parent_->CurrentTime(); + auto [att_key, att_err] = env::environment->Attest( + ts, + remote_attestation.evidence(), + remote_attestation.endorsements()); + if(att_err != error::OK) { + return att_err; + } + if(!util::ConstantTimeEquals(att_key, this->ID().Get())) { + return error::Peers_FinishIDMismatch; + } + + NoiseBuffer buf = noise::BufferInputFromString(conn->mutable_handshake()); + if (NOISE_ERROR_NONE != noise_handshakestate_read_message(local_handshake.get(), &buf, nullptr)) { + return COUNTED_ERROR(Peers_FinishReadHandshake); + 
} + if (NOISE_ACTION_SPLIT != noise_handshakestate_get_action(local_handshake.get())) { + return COUNTED_ERROR(Peers_HandshakeState); + } + NoiseCipherState* tx; + NoiseCipherState* rx; + if (NOISE_ERROR_NONE != noise_handshakestate_split(local_handshake.get(), &tx, &rx)) { + return COUNTED_ERROR(Peers_FinishSplit); + } + + tx_.reset(tx); + rx_.reset(rx); + auto e2e_message = ctx->Protobuf(); + e2e_message->set_connected(true); + *decoded = e2e_message; + last_attestation_ = ts; + return error::OK; +} + +error::Error Peer::Connect( + context::Context* ctx, + const noise::DHState& priv, + const e2e::Attestation& attestation) { + ACQUIRE_LOCK(mu_, ctx, lock_peer); + return Peer::InternalConnect(ctx, priv, attestation); +} + +std::pair Peer::MaybeConnect( + context::Context* ctx, + const noise::DHState& priv, + const e2e::Attestation& attestation) { + ACQUIRE_LOCK(mu_, ctx, lock_peer); + switch (InternalCurrentState()) { + case PEER_CONNECTING: + case PEER_CONNECTED: + return std::make_pair(false, error::OK); + case PEER_DISCONNECTED: + default: + return std::make_pair(true, InternalConnect(ctx, priv, attestation)); + } +} + +error::Error Peer::InternalConnect( + context::Context* ctx, + const noise::DHState& priv, + const e2e::Attestation& attestation) { + MEASURE_CPU(ctx, cpu_peer_connect); + RETURN_IF_ERROR(Reset(priv, NOISE_ROLE_INITIATOR)); + CHECK(handshake_.get()); + + // Take away our class state for the duration of this call, so that if something goes + // wrong we don't have a misbehaving handshake lying around. + noise::HandshakeState local_handshake = noise::WrapHandshakeState(nullptr); + local_handshake.swap(handshake_); + if (NOISE_ACTION_WRITE_MESSAGE != noise_handshakestate_get_action(local_handshake.get())) { + return COUNTED_ERROR(Peers_HandshakeState); + } + + e2e::ConnectRequest* conn = ctx->Protobuf(); + conn->mutable_attestation()->CopyFrom(attestation); + + // Create the initial Noise initiator handshake request buffer in [conn->handshake]. 
// Responder side of the peer handshake: handles an incoming `syn`.
//
// Resets this peer into a fresh responder handshake, parses `syn` as an
// e2e::ConnectRequest, verifies its attestation against this peer's ID, reads
// the initiator's handshake message, writes our reply (carrying our own
// `attestation`) into a synack, splits the handshake into tx_/rx_, and
// finally sends the synack.  On success *decoded is set to a
// context-allocated message with connected=true.
//
// The synack is only sent after every step has succeeded.
error::Error Peer::Accept(
    context::Context* ctx,
    const noise::DHState& priv,
    const e2e::Attestation& attestation,
    const std::string& syn,
    e2e::EnclaveToEnclaveMessage** decoded) {
  ACQUIRE_LOCK(mu_, ctx, lock_peer);
  MEASURE_CPU(ctx, cpu_peer_accept);
  RETURN_IF_ERROR(Reset(priv, NOISE_ROLE_RESPONDER));
  CHECK(handshake_.get());

  // Take away our class state for the duration of this call, so that if something goes
  // wrong we don't have a misbehaving handshake lying around.
  noise::HandshakeState local_handshake = noise::WrapHandshakeState(nullptr);
  local_handshake.swap(handshake_);

  if (NOISE_ACTION_READ_MESSAGE != noise_handshakestate_get_action(local_handshake.get())) {
    return COUNTED_ERROR(Peers_HandshakeState);
  }

  e2e::ConnectRequest* conn_request = ctx->Protobuf();
  if (!conn_request->ParseFromString(syn)) {
    return COUNTED_ERROR(Peers_AcceptParseHandshake);
  }

  // validate the attestation
  auto remote_attestation = conn_request->attestation();
  auto ts = parent_->CurrentTime();
  auto [att_key, att_err] = env::environment->Attest(
      ts,
      remote_attestation.evidence(),
      remote_attestation.endorsements());
  if(att_err != error::OK) {
    return att_err;
  }
  if(!util::ConstantTimeEquals(att_key, this->ID().Get())) {
    // NOTE(review): this is returned without COUNTED_ERROR, unlike the other
    // locally generated errors in this function, so an ID mismatch is not
    // recorded in the error metrics - confirm whether that is intended.
    return error::Peers_AcceptIDMismatch;
  }

  NoiseBuffer read_buf = noise::BufferInputFromString(conn_request->mutable_handshake());
  // NOTE(review): `err` is assigned but never read afterwards.
  int err = 0;
  if (NOISE_ERROR_NONE != (err = noise_handshakestate_read_message(local_handshake.get(), &read_buf, nullptr))) {
    return COUNTED_ERROR(Peers_AcceptReadHandshake);
  }
  if (NOISE_ACTION_WRITE_MESSAGE != noise_handshakestate_get_action(local_handshake.get())) {
    return COUNTED_ERROR(Peers_HandshakeState);
  }
  auto conn_response = ctx->Protobuf();
  conn_response->mutable_attestation()->CopyFrom(attestation);
  // Reserve HANDSHAKE_INIT_SIZE bytes for the reply, then shrink to what
  // Noise actually wrote.
  conn_response->mutable_handshake()->resize(noise::HANDSHAKE_INIT_SIZE, '\0');
  NoiseBuffer write_buf = noise::BufferOutputFromString(conn_response->mutable_handshake());
  if (NOISE_ERROR_NONE != noise_handshakestate_write_message(local_handshake.get(), &write_buf, nullptr)) {
    return COUNTED_ERROR(Peers_AcceptWriteHandshake);
  }
  conn_response->mutable_handshake()->resize(write_buf.size, 0);
  auto enclave_message = ctx->Protobuf();
  auto send = enclave_message->mutable_peer_message();
  id_.ToString(send->mutable_peer_id());
  if (!conn_response->SerializeToString(send->mutable_synack())) {
    return COUNTED_ERROR(Peers_AcceptSerializeHandshake);
  }
  // After reading the syn and writing our reply, the handshake must be ready
  // to split.
  if (NOISE_ACTION_SPLIT != noise_handshakestate_get_action(local_handshake.get())) {
    return COUNTED_ERROR(Peers_HandshakeState);
  }

  NoiseCipherState* tx;
  NoiseCipherState* rx;
  if (NOISE_ERROR_NONE != noise_handshakestate_split(local_handshake.get(), &tx, &rx)) {
    return COUNTED_ERROR(Peers_AcceptSplit);
  }
  tx_.reset(tx);
  rx_.reset(rx);

  auto e2e_message = ctx->Protobuf();
  e2e_message->set_connected(true);
  *decoded = e2e_message;
  last_attestation_ = ts;
  sender::Send(*enclave_message);
  return error::OK;
}
COUNTED_ERROR(Peers_CopyDHState); + } + if (NOISE_ERROR_NONE != noise_dhstate_set_public_key( + noise_handshakestate_get_remote_public_key_dh(hsp), + id_.Get().data(), + id_.Get().size())) { + return COUNTED_ERROR(Peers_SetRemotePublicKey); + } + if (NOISE_ERROR_NONE != noise_handshakestate_start(hsp)) { + return COUNTED_ERROR(Peers_HandshakeStart); + } + + handshake_.swap(hs); + return error::OK; +} + +error::Error Peer::CheckNextAttestation(const e2e::Attestation& a) { + auto now = parent_->CurrentTime(); + auto [key, err] = env::environment->Attest(now, a.evidence(), a.endorsements()); + RETURN_IF_ERROR(err); + if (!util::ConstantTimeEquals(key, id_.Get())) { + LOG(ERROR) << "Peer " << id_ << " sent attestation with incorrect key"; + return COUNTED_ERROR(Peers_AttestationKeyChanged); + } + LOG(DEBUG) << "Peer " << id_ << " re-attested at " << now; + last_attestation_ = now; + return error::OK; +} + +void Peer::MaybeDisconnectIfAttestationTooOld(context::Context* ctx, util::UnixSecs now, util::UnixSecs attestation_timeout) { + ACQUIRE_LOCK(mu_, ctx, lock_peer); + auto state = InternalCurrentState(); + if (// If we're already disconnected ... + state == PEER_DISCONNECTED || + // ... or our attestation timestamp is in a good range ... + (now <= last_attestation_ + attestation_timeout && now >= last_attestation_ - attestation_timeout) || + // ... or we're connecting and we haven't yet received a synack with an attestation ... + (state == PEER_CONNECTING && last_attestation_ == 0)) { + // ... then there's no need for us to disconnect due to attestation timestamp. 
+ return; + } + LOG(WARNING) << "Attestation for " << id_ << " too old (ts=" << last_attestation_ << ", now=" << now << "), disconnecting"; + InternalDisconnect(); + SendRst(ctx, id_); +} + +void Peer::PopulateConnectionStatus(context::Context* ctx, ConnectionStatus* status) const { + ACQUIRE_LOCK(mu_, ctx, lock_peer); + status->set_state(InternalCurrentState()); + status->set_last_attestation_unix_secs(last_attestation_); +} + +PeerManager::PeerManager() + : dhstate_(noise::WrapDHState(nullptr)), + init_success_(false), + time_(0) {} +PeerManager::~PeerManager() {} + +PeerState PeerManager::PeerState(context::Context* ctx, const peerid::PeerID& id) const { + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + auto finder = peers_.find(id); + if (finder == peers_.end()) { + return PEER_DISCONNECTED; + } + return finder->second->CurrentState(ctx); +} + +error::Error PeerManager::Init(context::Context* ctx) { + NoiseDHState* dhstate; + if (NOISE_ERROR_NONE != noise_dhstate_new_by_id(&dhstate, peer_to_peer_protocol.dh_id)) { + return COUNTED_ERROR(Peers_NewKey); + } + noise::DHState dh = noise::WrapDHState(dhstate); + if (NOISE_ERROR_NONE != noise_dhstate_generate_keypair(dhstate)) { + return COUNTED_ERROR(Peers_NewKeyGenerate); + } + env::PublicKey public_key{}; + if (NOISE_ERROR_NONE != noise_dhstate_get_public_key(dhstate, public_key.data(), sizeof(public_key))) { + return COUNTED_ERROR(Peers_NewKeyPublic); + } + + auto [evidence_and_endorsements, err] = env::environment->Evidence(public_key, enclaveconfig::RaftGroupConfig()); + RETURN_IF_ERROR(err); + + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + if (init_success_.exchange(true)) { + return COUNTED_ERROR(Peers_ReInit); + } + dhstate_.swap(dh); + id_ = peerid::PeerID(public_key.data()); + most_recent_attestation_.CopyFrom(evidence_and_endorsements); + return error::OK; +} + +static peerid::PeerID invalid_id; + +const peerid::PeerID& PeerManager::ID() const { + if (!init_success_.load()) { return invalid_id; } + return id_; 
+} + +error::Error PeerManager::RefreshAttestation(context::Context* ctx) { + auto [evidence_and_endorsements, err] = env::environment->Evidence(ID().Get(), enclaveconfig::RaftGroupConfig()); + if (err != error::OK) { + COUNTER(peers, attestation_refresh_failure)->Increment(); + return err; + } + + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + COUNTER(peers, attestation_refresh_success)->Increment(); + most_recent_attestation_ = evidence_and_endorsements; + auto msg = ctx->Protobuf(); + *msg->mutable_attestation_update() = most_recent_attestation_; + + LOG(DEBUG) << "Sending refreshed attestation to peers"; + for (auto iter = peers_.begin(); iter != peers_.end(); ++iter) { + if (iter->second->CurrentState(ctx) == PEER_CONNECTED) { + auto err = iter->second->Send(ctx, *msg); + LOG(VERBOSE) << "Sent refreshed attestation to " << iter->first << ": " << err; + if (err != error::OK) { + LOG(WARNING) << "Sending most recent attestation to " << iter->first << " failed: " << err; + } + } + } + return error::OK; +} + +Peer* PeerManager::CreatePeer(context::Context* ctx, const peerid::PeerID& id) { + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + auto finder = peers_.find(id); + if (finder != peers_.end()) { + return finder->second.get(); + } + auto [iter, _] = peers_.emplace(id, std::make_unique(id, this)); + GAUGE(peers, peers)->Set(peers_.size()); + return iter->second.get(); +} + +Peer* PeerManager::GetPeer(context::Context* ctx, const peerid::PeerID& id) const { + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + auto finder = peers_.find(id); + if (finder != peers_.end()) { + return finder->second.get(); + } + return nullptr; +} + +Peer* PeerManager::GetPeerOrRst(context::Context* ctx, const peerid::PeerID& id) const { + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + auto finder = peers_.find(id); + if (finder != peers_.end()) { + return finder->second.get(); + } + Peer::SendRst(ctx, id); + return nullptr; +} + + +std::pair PeerManager::ConnectionArgs(context::Context* ctx) { + 
CHECK(init_success_.load()); + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + auto att = ctx->Protobuf(); + *att = most_recent_attestation_; + return std::make_pair(noise::CloneDHState(dhstate_), att); +} + +error::Error PeerManager::ConnectToPeer( + context::Context* ctx, + const peerid::PeerID& to) { + if (!init_success_.load()) { + return COUNTED_ERROR(Peers_NoInit); + } + Peer* peer = CreatePeer(ctx, to); + auto [dhstate, most_recent_attestation] = ConnectionArgs(ctx); + RETURN_IF_ERROR(peer->Connect(ctx, dhstate, *most_recent_attestation)); + LOG(INFO) << ID() << " connecting to new peer " << to; + return error::OK; +} + +error::Error PeerManager::MaybeConnectToPeer( + context::Context* ctx, + const peerid::PeerID& to) { + if (!init_success_.load()) { + return COUNTED_ERROR(Peers_NoInit); + } + Peer* peer = CreatePeer(ctx, to); + auto [dhstate, most_recent_attestation] = ConnectionArgs(ctx); + auto [started_connection, err] = peer->MaybeConnect(ctx, dhstate, *most_recent_attestation); + RETURN_IF_ERROR(err); + if (started_connection) { + LOG(INFO) << ID() << " connecting to new peer " << to; + } + return error::OK; +} + +error::Error PeerManager::ResetPeer( + context::Context* ctx, + const peerid::PeerID& to) { + if (!init_success_.load()) { + return COUNTED_ERROR(Peers_NoInit); + } + Peer* peer = GetPeerOrRst(ctx, to); + if (peer == nullptr) { + return COUNTED_ERROR(Peers_ResetMissingPeer); + } + peer->Disconnect(ctx); + return error::OK; +} + +error::Error PeerManager::SendToPeer( + context::Context* ctx, + const peerid::PeerID& to, + const e2e::EnclaveToEnclaveMessage& msg) { + if (!init_success_.load()) { + return COUNTED_ERROR(Peers_NoInit); + } + if (msg.connected()) { + return COUNTED_ERROR(Peers_SendConnect); + } + Peer* peer = GetPeerOrRst(ctx, to); + if (peer == nullptr) { + return COUNTED_ERROR(Peers_SendBeforeConnect); + } + return peer->Send(ctx, msg); +} + +error::Error PeerManager::RecvFromPeer( + context::Context* ctx, + const PeerMessage& msg, + 
e2e::EnclaveToEnclaveMessage** decoded) { + if (!init_success_.load()) { + return COUNTED_ERROR(Peers_NoInit); + } + *decoded = nullptr; + peerid::PeerID from; + RETURN_IF_ERROR(from.FromString(msg.peer_id())); + if (msg.inner_case() == PeerMessage::kSyn) { + Peer* peer = CreatePeer(ctx, from); + auto [dhstate, most_recent_attestation] = ConnectionArgs(ctx); + RETURN_IF_ERROR(peer->Accept(ctx, dhstate, *most_recent_attestation, msg.syn(), decoded)); + LOG(INFO) << ID() << " accepted new peer " << from; + return error::OK; + } + + Peer* peer = msg.inner_case() == PeerMessage::kRst + ? GetPeer(ctx, from) + : GetPeerOrRst(ctx, from); + + if (peer == nullptr) { + return COUNTED_ERROR(Peers_RecvBeforeConnect); + } + return peer->Recv(ctx, msg, decoded); +} + +std::set PeerManager::ConnectedPeers(context::Context* ctx) const { + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + std::set out; + for (auto iter = peers_.cbegin(); iter != peers_.cend(); ++iter) { + if (iter->second->CurrentState(ctx) == PEER_CONNECTED) { + out.insert(iter->first); + } + } + return out; +} + +std::set PeerManager::AllPeers(context::Context* ctx) const { + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + std::set out; + for (auto iter = peers_.cbegin(); iter != peers_.cend(); ++iter) { + out.insert(iter->first); + } + return out; +} + +void PeerManager::SetPeerAttestationTimestamp(context::Context* ctx, util::UnixSecs secs, util::UnixSecs attestation_timeout) { + auto old_secs = time_.exchange(secs); + if (old_secs == secs) { + return; + } else if (old_secs > secs) { + LOG(WARNING) << "PeerManager timestamp went backwards: " << old_secs << " -> " << secs; + } + ACQUIRE_LOCK(mu_, ctx, lock_peermanager); + for (auto iter = peers_.begin(); iter != peers_.end(); ++iter) { + iter->second->MaybeDisconnectIfAttestationTooOld(ctx, secs, attestation_timeout); + } +} + +void PeerManager::PeerStatus(context::Context* ctx, const peerid::PeerID& id, ConnectionStatus* status) const { + auto peer = GetPeer(ctx, id); + 
if (peer == nullptr) { return; } + peer->PopulateConnectionStatus(ctx, status); +} + +} // namespace svr2::peers diff --git a/enclave/peers/peers.h b/enclave/peers/peers.h new file mode 100644 index 0000000..040f455 --- /dev/null +++ b/enclave/peers/peers.h @@ -0,0 +1,261 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_PEERS_PEERS_H__ +#define __SVR2_PEERS_PEERS_H__ + +#include +#include +#include +#include +#include +#include + +#include "context/context.h" +#include "util/macros.h" +#include "proto/error.pb.h" +#include "proto/msgs.pb.h" +#include "proto/e2e.pb.h" +#include "sip/hasher.h" +#include "noise/noise.h" +#include "peerid/peerid.h" +#include "groupclock/groupclock.h" +#include "util/mutex.h" + +// Within the peer manager, peers can make the following state transitions. The normal +// transition paths are: +// +// Connect: DISCONNECTED -> CONNECTING -> CONNECTED +// Accept: DISCONNECTED -> CONNECTED +// +// However, note that: +// +// - From the CONNECTING/CONNECTED states, one can enter the DISCONNECTED +// state should an error be encountered +// - From any state, one can enter CONNECTED by receiving a SYN and sending +// a SYN/ACK. +// - From DISCONNECTED, one can enter CONNECTING by sending a SYN. +// +// By utilizing these mechanisms, we should be able to re-establish good connections +// should any connection state become invalid.
+// +// ┌────────────────┐ error / recv:RST ┌────────────────┐ +// │ │◄────────────────────┤ │ +// │ DISCONNECTED │ │ CONNECTING │ +// │ ├────────────────────►│ │ +// └───────────┬────┘ send:SYN └────┬────────┬──┘ +// ▲ │ │ │ +// │ │ │recv: │recv:SYN +// │ │ │SYNACK │send:SYNACK +// │ │ │ │ +// │ │ recv:SYN ▼ ▼ +// │ │ send:SYNACK ┌────────────────┐ +// │ └─────────────────────────►│ │ +// │ │ CONNECTED │ +// └────────────────────────────────┤ │ +// error / recv:RST └────────────────┘ +// +// (made with asciiflow) + +namespace svr2::peers { + +class PeerManager; + +// Encapsulates the state for a single remote peer. +class Peer { + public: + DELETE_COPY_AND_ASSIGN(Peer); + Peer(const peerid::PeerID& id, PeerManager* parent); + + const peerid::PeerID& ID() const { return id_; } + error::Error Send( + context::Context* ctx, + const e2e::EnclaveToEnclaveMessage& msg) EXCLUDES(mu_); + error::Error Recv( + context::Context* ctx, + const PeerMessage& msg, + e2e::EnclaveToEnclaveMessage** decoded) EXCLUDES(mu_); + + // Connect is called on a newly created peer to request establishment of a new + // connection to that remote party. + error::Error Connect( + context::Context* ctx, + const noise::DHState& priv, + const e2e::Attestation& attestation) EXCLUDES(mu_); + // MaybeConnect is called when we're not sure if we're connected already + // or not. It won't disrupt an existing connection, but it will establish + // a new one. Returns a bool that says whether we attempted to start + // the connection or not. + std::pair MaybeConnect( + context::Context* ctx, + const noise::DHState& priv, + const e2e::Attestation& attestation) EXCLUDES(mu_); + // Accept is called on a newly created peer to request establishment of a + // remote-requested connection to that remote party. 
+ error::Error Accept( + context::Context* ctx, + const noise::DHState& priv, + const e2e::Attestation& attestation, + const std::string& syn, + e2e::EnclaveToEnclaveMessage** decoded) EXCLUDES(mu_); + // Disconnect the peer. + void Disconnect(context::Context* ctx) EXCLUDES(mu_); + // Disconnect the peer if its attestation timestamp is out of date. + void MaybeDisconnectIfAttestationTooOld(context::Context* ctx, util::UnixSecs now, util::UnixSecs attestation_timeout) EXCLUDES(mu_); + + PeerState CurrentState(context::Context* ctx) const EXCLUDES(mu_); + void PopulateConnectionStatus(context::Context* ctx, ConnectionStatus* status) const EXCLUDES(mu_); + + // Send a `rst` to the given peer ID. + static void SendRst(context::Context* ctx, const peerid::PeerID& id) EXCLUDES(mu_); + + private: + // Resets state to DISCONNECTED. + void InternalDisconnect() REQUIRES(mu_); + PeerState InternalCurrentState() const REQUIRES(mu_); + + // Connect is called on a newly created peer to request establishment of a new + // connection to that remote party. + error::Error InternalConnect( + context::Context* ctx, + const noise::DHState& priv, + const e2e::Attestation& attestation) REQUIRES(mu_); + + // FinishConnection is called by Recv when [state==CONNECTING] to complete the handshake. + error::Error FinishConnection( + context::Context* ctx, + const std::string& synack, + e2e::EnclaveToEnclaveMessage** decoded) REQUIRES(mu_); + + // Reset to a state where we have a valid handshake_. 
+ error::Error Reset( + const noise::DHState& priv, + int noise_role) REQUIRES(mu_); + + error::Error CheckNextAttestation(const e2e::Attestation& a) REQUIRES(mu_); + + const peerid::PeerID id_; + mutable util::mutex mu_; + noise::HandshakeState handshake_ GUARDED_BY(mu_); + noise::CipherState tx_ GUARDED_BY(mu_); + noise::CipherState rx_ GUARDED_BY(mu_); + const PeerManager* const parent_; + util::UnixSecs last_attestation_ GUARDED_BY(mu_); +}; + +// PeerManager allows messages to be sent to and received from peers. +// +// Connecting to a new peer: +// +// Connector Accepter +// ---------------------------------------------------------- +// ConnectToPeer(accepter) +// - encoded = handshake request -> RecvFromPeer(connector, msg.syn) +// - decoded = e2e.connect +// RecvFromPeer(accepter, msg) <- - encoded = handshake response +// - decoded = e2e.connect +// - encoded = NULL +// +// The connector's first message contains the most recent attestation proof +// for the connector's communication public key, along with a Noise handshake +// for that key. +// +// The accepter's first message contains its attestation proof for its +// communication public key, as well as the noise handshake completion for +// this session. The [decoded] message that comes out will have the +// [e2e.connect] flag set when the handshake is complete and the session +// is considered usable. +// +// After a client has connected, this manager can handle received messages +// by passing them to RecvFromPeer and handling the resulting [decoded] +// message, and can send messages by passing them through SendToPeer, then +// sending the resulting EnclaveMessage up to the host for processing. 
+class PeerManager { + public: + DELETE_COPY_AND_ASSIGN(PeerManager); + PeerManager(); + ~PeerManager(); + + error::Error Init(context::Context* ctx) EXCLUDES(mu_); + + error::Error RefreshAttestation(context::Context* ctx); + + // ConnectToPeer requests that a new connection be established to the given + // PeerID. This will replace any existing connections that might exist + // with that peer with a new connection. + error::Error ConnectToPeer( + context::Context* ctx, + const peerid::PeerID& to); + + // Try to establish a connection to [to] if one doesn't already exist. + // If we're already connected or already attempting to connect, does nothing + // and returns success. + error::Error MaybeConnectToPeer( + context::Context* ctx, + const peerid::PeerID& to); + + // ResetPeer disconnects a peer and sends it an RST. + error::Error ResetPeer( + context::Context* ctx, + const peerid::PeerID& to); + + // SendToPeer takes in a serialized protobuf to send to [to]. If + // [msg.connect] is set, then this is requesting a new connection to [to] + // rather than sending on an existing channel. + // Note: does not actually send the message in question, just encodes it. + error::Error SendToPeer( + context::Context* ctx, + const peerid::PeerID& to, + const e2e::EnclaveToEnclaveMessage& msg); + + // RecvFromPeer takes in a PeerMessage and decodes it. If that + // message contains a EnclaveToEnclaveMessage, that message is instantiated + // in the provided [arena] and returned as [*decoded]. If not, [*decoded] + // will be NULL. If [*encoded] is not null, it should be sent up to the + // host. + // If this message establishes a connection, [*decoded.connect] will be set. + error::Error RecvFromPeer( + context::Context* ctx, + const PeerMessage& msg, + e2e::EnclaveToEnclaveMessage** decoded); + + // Returns the local identifier (public key) that remote peers use to connect + // to this peer manager. 
+ const peerid::PeerID& ID() const NO_THREAD_SAFETY_ANALYSIS; + + // Get the current state of a peer ID. + PeerState PeerState(context::Context* ctx, const peerid::PeerID& id) const; + + std::set ConnectedPeers(context::Context* ctx) const; + std::set AllPeers(context::Context* ctx) const; + void PeerStatus(context::Context* ctx, const peerid::PeerID& id, ConnectionStatus* status) const; + + void SetPeerAttestationTimestamp(context::Context* ctx, util::UnixSecs secs, util::UnixSecs attestation_timeout) EXCLUDES(mu_); + + util::UnixSecs CurrentTime() const { return time_.load(); } + + private: + std::pair ConnectionArgs(context::Context* ctx) EXCLUDES(mu_); + + // CreatePeer returns a peer for the given ID, creating it if necessary. + Peer* CreatePeer(context::Context* ctx, const peerid::PeerID& id) EXCLUDES(mu_); + // GetPeer returns the peer associated with the given ID + Peer* GetPeer(context::Context* ctx, const peerid::PeerID& id) const EXCLUDES(mu_); + // GetPeerOrRst returns the peer associated with the given ID, sending + // a RST to that peer if it doesn't exist. + Peer* GetPeerOrRst(context::Context* ctx, const peerid::PeerID& id) const EXCLUDES(mu_); + + mutable util::mutex mu_; + // To simplify multi-threaded logic, a peer once added to `peers_` will + // never be removed. 
+ std::unordered_map, peerid::PeerIDHasher> peers_ GUARDED_BY(mu_); + noise::DHState dhstate_ GUARDED_BY(mu_); + peerid::PeerID id_ GUARDED_BY(mu_); + e2e::Attestation most_recent_attestation_ GUARDED_BY(mu_); + std::atomic init_success_; + std::atomic time_; +}; + +} // namespace svr2::peers + +#endif // __SVR2_PEERS_PEERS_H__ diff --git a/enclave/peers/tests/peermanager.cc b/enclave/peers/tests/peermanager.cc new file mode 100644 index 0000000..fd60cd4 --- /dev/null +++ b/enclave/peers/tests/peermanager.cc @@ -0,0 +1,207 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP context +//TESTDEP noise +//TESTDEP noise-c +//TESTDEP noisewrap +//TESTDEP peerid +//TESTDEP sip +//TESTDEP sender +//TESTDEP env +//TESTDEP env/test +//TESTDEP util +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include "peers/peers.h" +#include "env/env.h" +#include "env/test/test.h" +#include "util/log.h" +#include "proto/e2e.pb.h" +#include +#include + +namespace svr2::peers { + +#define ATTESTATION_TIMEOUT 3600 + +class PeerManagerTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } + + e2e::Attestation attestation; + context::Context ctx; + + PeerMessage* FromEnclaveMessage(const EnclaveMessage& msg, const peerid::PeerID& from) { + auto out = ctx.Protobuf(); + out->MergeFrom(msg.peer_message()); + from.ToString(out->mutable_peer_id()); + return out; + } + + void SetUp() { + mgr1 = std::make_unique(); + mgr2 = std::make_unique(); + ASSERT_EQ(error::OK, mgr1->Init(&ctx)); + ASSERT_EQ(error::OK, mgr2->Init(&ctx)); + mgr1->SetPeerAttestationTimestamp(&ctx, now, ATTESTATION_TIMEOUT); + mgr2->SetPeerAttestationTimestamp(&ctx, now, ATTESTATION_TIMEOUT); + env::test::SentMessages(); // clear sent messages from previous tests + } + + EnclaveMessage Sent() { + auto msgs = env::test::SentMessages(); + CHECK(msgs.size() == 1); + return 
std::move(msgs[0]); + } + + void Connect1To2() { + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::OK, mgr1->ConnectToPeer(&ctx, mgr2->ID())); + EnclaveMessage em = Sent(); + ASSERT_EQ(em.inner_case(), EnclaveMessage::kPeerMessage); + ASSERT_EQ(error::OK, mgr2->RecvFromPeer(&ctx, *FromEnclaveMessage(em, mgr1->ID()), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_TRUE(e2e->connected()); + em = Sent(); + ASSERT_EQ(error::OK, mgr1->RecvFromPeer(&ctx, *FromEnclaveMessage(em, mgr2->ID()), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_TRUE(e2e->connected()); + } + + google::protobuf::Arena arena; + std::unique_ptr mgr1; + std::unique_ptr mgr2; + util::UnixSecs now = 1000; +}; + +TEST_F(PeerManagerTest, SuccessfulCommunicationAcrossManagers) { + Connect1To2(); + e2e::EnclaveToEnclaveMessage* e2e; + e2e::EnclaveToEnclaveMessage send; + send.mutable_raft_message()->set_term(123); + ASSERT_EQ(error::OK, mgr1->SendToPeer(&ctx, mgr2->ID(), send)); + EnclaveMessage em = Sent(); + ASSERT_EQ(error::OK, mgr2->RecvFromPeer(&ctx, *FromEnclaveMessage(em, mgr1->ID()), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_EQ(e2e->raft_message().term(), 123); +} + +TEST_F(PeerManagerTest, SendConnected) { + Connect1To2(); + e2e::EnclaveToEnclaveMessage send; + send.set_connected(true); + ASSERT_EQ(error::Peers_SendConnect, mgr1->SendToPeer(&ctx, mgr2->ID(), send)); +} + +TEST_F(PeerManagerTest, AcceptUnparsable) { + PeerMessage msg; + msg.set_syn("this is not parsable protobuf serialized data"); + mgr1->ID().ToString(msg.mutable_peer_id()); + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::Peers_AcceptParseHandshake, mgr2->RecvFromPeer(&ctx, msg, &e2e)); + ASSERT_EQ(PEER_DISCONNECTED, mgr2->PeerState(&ctx, mgr1->ID())); +} + +TEST_F(PeerManagerTest, RecvConnectToConnected) { + ASSERT_EQ(error::OK, mgr2->ConnectToPeer(&ctx, mgr1->ID())); + Sent(); + Connect1To2(); +} + +TEST_F(PeerManagerTest, FinishConnectUnparsable) { + PeerMessage msg; + msg.set_synack("this is not parsable protobuf 
serialized data"); + mgr1->ID().ToString(msg.mutable_peer_id()); + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::OK, mgr2->ConnectToPeer(&ctx, mgr1->ID())); + ASSERT_EQ(error::Peers_FinishParseHandshake, mgr2->RecvFromPeer(&ctx, msg, &e2e)); + ASSERT_EQ(PEER_DISCONNECTED, mgr2->PeerState(&ctx, mgr1->ID())); +} + +TEST_F(PeerManagerTest, ConnectToConnected) { + Connect1To2(); + Connect1To2(); +} + +TEST_F(PeerManagerTest, ReInit) { + ASSERT_EQ(error::Peers_ReInit, mgr1->Init(&ctx)); +} + +TEST_F(PeerManagerTest, NoInit) { + mgr1 = std::make_unique(); + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::Peers_NoInit, mgr1->ConnectToPeer(&ctx, mgr2->ID())); + PeerMessage msg; + ASSERT_EQ(error::Peers_NoInit, mgr1->RecvFromPeer(&ctx, msg, &e2e)); + e2e::EnclaveToEnclaveMessage send; + ASSERT_EQ(error::Peers_NoInit, mgr1->SendToPeer(&ctx, mgr2->ID(), send)); + ASSERT_FALSE(mgr1->ID().Valid()); +} + +TEST_F(PeerManagerTest, PeerState) { + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(PEER_DISCONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); + ASSERT_EQ(error::OK, mgr1->ConnectToPeer(&ctx, mgr2->ID())); + ASSERT_EQ(PEER_CONNECTING, mgr1->PeerState(&ctx, mgr2->ID())); + EnclaveMessage em = Sent(); + ASSERT_EQ(em.inner_case(), EnclaveMessage::kPeerMessage); + ASSERT_EQ(PEER_DISCONNECTED, mgr2->PeerState(&ctx, mgr1->ID())); + ASSERT_EQ(error::OK, mgr2->RecvFromPeer(&ctx, *FromEnclaveMessage(em, mgr1->ID()), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_TRUE(e2e->connected()); + ASSERT_EQ(PEER_CONNECTED, mgr2->PeerState(&ctx, mgr1->ID())); + em = Sent(); + ASSERT_EQ(error::OK, mgr1->RecvFromPeer(&ctx, *FromEnclaveMessage(em, mgr2->ID()), &e2e)); + ASSERT_NE(e2e, nullptr); + ASSERT_TRUE(e2e->connected()); + ASSERT_EQ(PEER_CONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); +} + +TEST_F(PeerManagerTest, TimeoutAttestation) { + Connect1To2(); + ASSERT_EQ(PEER_CONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); + mgr1->SetPeerAttestationTimestamp(&ctx, now, ATTESTATION_TIMEOUT); + 
ASSERT_EQ(PEER_CONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); + // Go up to but not over threshold. + mgr1->SetPeerAttestationTimestamp(&ctx, now + ATTESTATION_TIMEOUT, ATTESTATION_TIMEOUT); + ASSERT_EQ(PEER_CONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); + // Actually go over threshold. + mgr1->SetPeerAttestationTimestamp(&ctx, now + ATTESTATION_TIMEOUT + 1, ATTESTATION_TIMEOUT); + ASSERT_EQ(PEER_DISCONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); + // Confirm that RST was sent. + EnclaveMessage em = Sent(); + ASSERT_EQ(em.peer_message().inner_case(), PeerMessage::kRst); + ASSERT_EQ(em.peer_message().peer_id(), mgr2->ID().AsString()); +} + +TEST_F(PeerManagerTest, AttestationRefreshStallsTimeout) { + Connect1To2(); + ASSERT_EQ(PEER_CONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); + mgr1->SetPeerAttestationTimestamp(&ctx, now, ATTESTATION_TIMEOUT); + ASSERT_EQ(PEER_CONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); + // Go up to but not over threshold. + mgr1->SetPeerAttestationTimestamp(&ctx, now + ATTESTATION_TIMEOUT, ATTESTATION_TIMEOUT); + mgr2->SetPeerAttestationTimestamp(&ctx, now + ATTESTATION_TIMEOUT, ATTESTATION_TIMEOUT); + ASSERT_EQ(PEER_CONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); + ASSERT_EQ(PEER_CONNECTED, mgr2->PeerState(&ctx, mgr1->ID())); + + ASSERT_EQ(error::OK, mgr2->RefreshAttestation(&ctx)); + EnclaveMessage em = Sent(); + + e2e::EnclaveToEnclaveMessage* e2e; + ASSERT_EQ(error::OK, mgr1->RecvFromPeer(&ctx, *FromEnclaveMessage(em, mgr2->ID()), &e2e)); + ASSERT_TRUE(e2e == nullptr); + + mgr1->SetPeerAttestationTimestamp(&ctx, now + ATTESTATION_TIMEOUT + ATTESTATION_TIMEOUT, ATTESTATION_TIMEOUT); + ASSERT_EQ(PEER_CONNECTED, mgr1->PeerState(&ctx, mgr2->ID())); +} + +} // namespace svr2::peers diff --git a/enclave/proto/clientlog.proto b/enclave/proto/clientlog.proto new file mode 100644 index 0000000..5fb04d2 --- /dev/null +++ b/enclave/proto/clientlog.proto @@ -0,0 +1,26 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: 
AGPL-3.0-only + +syntax = "proto3"; + +package svr2.client; +option optimize_for = LITE_RUNTIME; + +import "client.proto"; +import "client3.proto"; + +// Log2 is the logged message used by the SVR2 (db2) database. +message Log2 { + bytes backup_id = 1; + client.Request req = 2; +} + +// Log3 is the logged message used by the SVR3 (db3) database. +message Log3 { + bytes backup_id = 1; + client.Request3 req = 2; + // If req.create(), then we need to generate new keys. + // These fields will be filled in with the generated keys. + bytes create_privkey = 3; + bytes create_pubkey = 4; +} diff --git a/enclave/proto/e2e.proto b/enclave/proto/e2e.proto new file mode 100644 index 0000000..73d570a --- /dev/null +++ b/enclave/proto/e2e.proto @@ -0,0 +1,156 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Contains messages for enclave-to-enclave (e2e) communication over a peer connection. +syntax = "proto3"; + +package svr2.e2e; +option optimize_for = LITE_RUNTIME; + +import "msgs.proto"; +import "raft.proto"; +import "error.proto"; +import "enclaveconfig.proto"; + +// +// transactions +// + +// +// enclave-to-enclave requests +// + +message Attestation { + bytes evidence = 1; + bytes endorsements = 2; +} + +// +// replica to replica +// + +message EnclaveToEnclaveMessage { + oneof inner { + bool connected = 1; + raft.RaftMessage raft_message = 2; + // attestation_update messages are sent within an EnclaveToEnclaveMessage + // regularly across peer-to-peer links, to keep the remote party happy that the + // local party is still able to attest. It's up to the local party to send these + // regularly; not doing so can shut down a connection. 
Attestation attestation_update = 3; + + // Enclave-to-enclave transactions, requests and responses + TransactionRequest transaction_request = 4; + TransactionResponse transaction_response = 5; + }; +} + +message ConnectRequest { + Attestation attestation = 1; + bytes handshake = 2; +} + +message TransactionRequest { + uint64 request_id = 1; + oneof inner { + bool ping = 2; // should return status=OK + bool get_raft = 3; + ReplicateStateRequest replicate_state = 4; + ReplicateStatePush replicate_state_push = 5; + bool raft_membership_request = 6; + bool raft_voting_request = 7; + bytes raft_write = 8; + uint64 new_timestamp_unix_secs = 9; + bool raft_removal_request = 10; + } +} + +message TransactionResponse { + uint64 request_id = 1; + oneof inner { + error.Error status = 2; + GetRaftResponse get_raft = 3; + raft.LogLocation raft_membership_response = 4; + raft.LogLocation raft_voting_response = 5; + raft.LogLocation raft_write = 6; + } +} + +message GetRaftResponse { + enclaveconfig.RaftGroupConfig group_config = 1; + raft.ReplicaGroup replica_group = 2; +} + +message DB2RowState { + bytes backup_id = 1; + uint32 tries = 2; + bytes data = 3; + bytes pin = 4; + enum State { + UNINITIATED = 0; + POPULATED = 1; + AVAILABLE = 2; + } + State state = 5; +} + +message DB3RowState { + bytes backup_id = 1; + bytes priv = 2; + uint32 tries = 3; +} + +// --- Replication of State --- +// +// ReplicateStateRequest and ReplicateStatePush allow a new and not-yet-part-of-Raft +// replica to get state from an existing, part-of-Raft replica. By sending a number of +// ..Requests and getting associated ..Responses, the requester will get its Raft log and +// database to a state up to the responder's last committed index. It can then join the +// Raft group with this log/db and become a contributing member. +// +// Requesters move chunk by chunk through the Raft log and the db simultaneously.
To do +// this, requesters track cursors pointing to their current location in the raft log +// and db (both initially unset) and provide them on every request. +// +// Responders must ensure +// 1. No uncommitted log entries are returned +// 2. Returned rows will be in the range (req.db_from_key_exclusive, rows[-1]], +// and must reflect the state of the db in that range at the time of the last +// returned log index (resp.entries[-1]) +// +// If a responder has many committed log entries that have already been applied to their db, +// they may have to return no db rows in a response in order to ensure property 2 is met. +// +// Request{} -> +// <- Push{first=oldest responder has, entries=[...], rows=[]} +// <- Push{first=logs_from_idx_inclusive, entries=[...]} +// ... the pusher reaches their commit index, and can now return DB state +// <- Push{first=logs_from_idx_inclusive, entries=[...], rows=[...]} +// <- Push{first=logs_from_idx_inclusive, entries=[...], rows=[...]} +// ... +// <- Push{first=logs_from_idx_inclusive, entries=[...], rows=[...], db_to_end=true} +// <- status=OK +// +// At this time, the requester has all database and log state and may request +// entry into the Raft group. +message ReplicateStateRequest { + uint64 group_id = 1; + uint64 replication_id = 2; +} +message ReplicateStatePush { + uint64 replication_id = 1; + uint64 replication_sequence = 2; + + // Log replication, only committed logs will be returned. + uint64 first_log_idx = 3; + repeated raft.LogEntry entries = 4; + + // Database replication (all rows as of commitment of last row in [entries]) + bool db_to_end = 5; // true if the database range is ..., end_of_db] + repeated bytes rows = 6; // Rows are serialized protos in a database-specific format. + + // Raft membership at the commit index of the source. + // This may be set even if this response doesn't contain + // logs up to the point of the commit idx.
+ raft.ReplicaGroup committed_membership = 7; +} diff --git a/enclave/proto/raft.proto b/enclave/proto/raft.proto new file mode 100644 index 0000000..5853675 --- /dev/null +++ b/enclave/proto/raft.proto @@ -0,0 +1,76 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2.raft; +option optimize_for = LITE_RUNTIME; + +message Replica { + bytes peer_id = 1; + bool voting = 2; +} + +// ReplicaGroup contains information on the current configuration (set of +// replicas) for Raft. Importantly, for a particular raft.pb.cc instantiation, +// it must have a deterministic serialization. This means that, while +// serialization might change were the underlying protobuf library version +// to be bumped, the generated code for a static version of the protobuf +// library should be deterministic. In particular, `map` fields should not +// be present in this proto or its children. This determinism is necessary +// since the log's hash chain is updated based on the serialization of this +// proto. 
+message ReplicaGroup { + repeated Replica replicas = 1; +} + +message RaftMessage { + uint64 group = 1; + uint64 term = 2; + oneof inner { + VoteRequest vote_request = 3; + VoteResponse vote_response = 4; + AppendRequest append_request = 5; + AppendResponse append_response = 6; + bool timeout_now = 7; // force an election timeout on the recipient + }; +} + +message VoteRequest { + uint64 last_log_idx = 1; + uint64 last_log_term = 2; +} + +message VoteResponse { + bool vote_granted = 1; +} + +message AppendRequest { + uint64 prev_log_idx = 1; + uint64 prev_log_term = 2; + uint64 leader_commit = 3; + repeated LogEntry entries = 4; + uint64 leader_promise = 5; +} + +message AppendResponse { + bool success = 1; + uint64 match_idx = 2; + uint64 last_log_idx = 3; + uint64 promise_idx = 4; +} + +message LogEntry { + uint64 term = 1; + oneof inner { + bytes data = 2; + ReplicaGroup membership_change = 3; + } + bytes hash_chain = 4; +} + +message LogLocation { + uint64 term = 1; + uint64 idx = 2; + bytes hash_chain = 3; +} diff --git a/enclave/proto/tests.proto b/enclave/proto/tests.proto new file mode 100644 index 0000000..3c0d8e8 --- /dev/null +++ b/enclave/proto/tests.proto @@ -0,0 +1,11 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2.tests; +option optimize_for = LITE_RUNTIME; + +message SimplePB { + string str = 1; +} diff --git a/enclave/protobuf b/enclave/protobuf new file mode 160000 index 0000000..dab4d24 --- /dev/null +++ b/enclave/protobuf @@ -0,0 +1 @@ +Subproject commit dab4d24d44eea0f21d6a21a548ee2b8c22b37f4f diff --git a/enclave/protobuf-lite/README.md b/enclave/protobuf-lite/README.md new file mode 100644 index 0000000..1d73705 --- /dev/null +++ b/enclave/protobuf-lite/README.md @@ -0,0 +1,20 @@ +# Compilation of libprotobuf-lite.a + +Rather than rely on libprotobuf to build libprotobuf-lite.a, we just +symlink all necessary files here, then build with our typical 
+`Makefile.subdir` approach. This makes absolutely sure that we're +only linking to and compiling with the normal mechanisms. + +## Which files? + +If you're a future person that's looking to update the protobuf dependency, +this list of symlinks was found by doing: + +``` +cd ../protobuf +autoreconf -i +./configure +(cd src && make libprotobuf-lite.la) +``` + +and looking at the `CXX` rules that were executed. diff --git a/enclave/protobuf-lite/any_lite.cc b/enclave/protobuf-lite/any_lite.cc new file mode 120000 index 0000000..c498683 --- /dev/null +++ b/enclave/protobuf-lite/any_lite.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/any_lite.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/arena.cc b/enclave/protobuf-lite/arena.cc new file mode 120000 index 0000000..e784660 --- /dev/null +++ b/enclave/protobuf-lite/arena.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/arena.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/arenastring.cc b/enclave/protobuf-lite/arenastring.cc new file mode 120000 index 0000000..1117180 --- /dev/null +++ b/enclave/protobuf-lite/arenastring.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/arenastring.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/arenaz_sampler.cc b/enclave/protobuf-lite/arenaz_sampler.cc new file mode 120000 index 0000000..98358e9 --- /dev/null +++ b/enclave/protobuf-lite/arenaz_sampler.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/arenaz_sampler.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/bytestream.cc b/enclave/protobuf-lite/bytestream.cc new file mode 120000 index 0000000..16a17fa --- /dev/null +++ b/enclave/protobuf-lite/bytestream.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/bytestream.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/coded_stream.cc b/enclave/protobuf-lite/coded_stream.cc new file mode 120000 index 0000000..98d7f65 --- /dev/null +++ b/enclave/protobuf-lite/coded_stream.cc @@ -0,0 +1 @@ 
+../protobuf/src/google/protobuf/io/coded_stream.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/common.cc b/enclave/protobuf-lite/common.cc new file mode 120000 index 0000000..554f98d --- /dev/null +++ b/enclave/protobuf-lite/common.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/common.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/extension_set.cc b/enclave/protobuf-lite/extension_set.cc new file mode 120000 index 0000000..ad3e2fd --- /dev/null +++ b/enclave/protobuf-lite/extension_set.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/extension_set.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/generated_enum_util.cc b/enclave/protobuf-lite/generated_enum_util.cc new file mode 120000 index 0000000..23937b0 --- /dev/null +++ b/enclave/protobuf-lite/generated_enum_util.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/generated_enum_util.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/generated_message_tctable_lite.cc b/enclave/protobuf-lite/generated_message_tctable_lite.cc new file mode 120000 index 0000000..e3a0eaf --- /dev/null +++ b/enclave/protobuf-lite/generated_message_tctable_lite.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/generated_message_tctable_lite.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/generated_message_util.cc b/enclave/protobuf-lite/generated_message_util.cc new file mode 120000 index 0000000..1a5c72f --- /dev/null +++ b/enclave/protobuf-lite/generated_message_util.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/generated_message_util.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/implicit_weak_message.cc b/enclave/protobuf-lite/implicit_weak_message.cc new file mode 120000 index 0000000..1d9800d --- /dev/null +++ b/enclave/protobuf-lite/implicit_weak_message.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/implicit_weak_message.cc \ No newline at end of file diff --git 
a/enclave/protobuf-lite/inlined_string_field.cc b/enclave/protobuf-lite/inlined_string_field.cc new file mode 120000 index 0000000..a492d64 --- /dev/null +++ b/enclave/protobuf-lite/inlined_string_field.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/inlined_string_field.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/int128.cc b/enclave/protobuf-lite/int128.cc new file mode 120000 index 0000000..f5b767a --- /dev/null +++ b/enclave/protobuf-lite/int128.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/int128.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/io_win32.cc b/enclave/protobuf-lite/io_win32.cc new file mode 120000 index 0000000..d2edc16 --- /dev/null +++ b/enclave/protobuf-lite/io_win32.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/io/io_win32.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/map.cc b/enclave/protobuf-lite/map.cc new file mode 120000 index 0000000..d390fe8 --- /dev/null +++ b/enclave/protobuf-lite/map.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/map.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/message_lite.cc b/enclave/protobuf-lite/message_lite.cc new file mode 120000 index 0000000..a514c3f --- /dev/null +++ b/enclave/protobuf-lite/message_lite.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/message_lite.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/parse_context.cc b/enclave/protobuf-lite/parse_context.cc new file mode 120000 index 0000000..c4e676d --- /dev/null +++ b/enclave/protobuf-lite/parse_context.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/parse_context.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/repeated_field.cc b/enclave/protobuf-lite/repeated_field.cc new file mode 120000 index 0000000..c3bc781 --- /dev/null +++ b/enclave/protobuf-lite/repeated_field.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/repeated_field.cc \ No newline at end of file diff --git 
a/enclave/protobuf-lite/repeated_ptr_field.cc b/enclave/protobuf-lite/repeated_ptr_field.cc new file mode 120000 index 0000000..35022e6 --- /dev/null +++ b/enclave/protobuf-lite/repeated_ptr_field.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/repeated_ptr_field.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/status.cc b/enclave/protobuf-lite/status.cc new file mode 120000 index 0000000..a6b2e9d --- /dev/null +++ b/enclave/protobuf-lite/status.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/status.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/statusor.cc b/enclave/protobuf-lite/statusor.cc new file mode 120000 index 0000000..00c079c --- /dev/null +++ b/enclave/protobuf-lite/statusor.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/statusor.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/stringpiece.cc b/enclave/protobuf-lite/stringpiece.cc new file mode 120000 index 0000000..40e9110 --- /dev/null +++ b/enclave/protobuf-lite/stringpiece.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/stringpiece.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/stringprintf.cc b/enclave/protobuf-lite/stringprintf.cc new file mode 120000 index 0000000..e00714f --- /dev/null +++ b/enclave/protobuf-lite/stringprintf.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/stringprintf.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/strtod.cc b/enclave/protobuf-lite/strtod.cc new file mode 120000 index 0000000..f80909e --- /dev/null +++ b/enclave/protobuf-lite/strtod.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/io/strtod.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/structurally_valid.cc b/enclave/protobuf-lite/structurally_valid.cc new file mode 120000 index 0000000..743040c --- /dev/null +++ b/enclave/protobuf-lite/structurally_valid.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/structurally_valid.cc \ No newline at end of file diff 
--git a/enclave/protobuf-lite/strutil.cc b/enclave/protobuf-lite/strutil.cc new file mode 120000 index 0000000..18ef580 --- /dev/null +++ b/enclave/protobuf-lite/strutil.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/strutil.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/time.cc b/enclave/protobuf-lite/time.cc new file mode 120000 index 0000000..166ce3e --- /dev/null +++ b/enclave/protobuf-lite/time.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/stubs/time.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/wire_format_lite.cc b/enclave/protobuf-lite/wire_format_lite.cc new file mode 120000 index 0000000..f05c70e --- /dev/null +++ b/enclave/protobuf-lite/wire_format_lite.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/wire_format_lite.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/zero_copy_stream.cc b/enclave/protobuf-lite/zero_copy_stream.cc new file mode 120000 index 0000000..3bf66c7 --- /dev/null +++ b/enclave/protobuf-lite/zero_copy_stream.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/io/zero_copy_stream.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/zero_copy_stream_impl.cc b/enclave/protobuf-lite/zero_copy_stream_impl.cc new file mode 120000 index 0000000..003cd3e --- /dev/null +++ b/enclave/protobuf-lite/zero_copy_stream_impl.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/io/zero_copy_stream_impl.cc \ No newline at end of file diff --git a/enclave/protobuf-lite/zero_copy_stream_impl_lite.cc b/enclave/protobuf-lite/zero_copy_stream_impl_lite.cc new file mode 120000 index 0000000..e4456fe --- /dev/null +++ b/enclave/protobuf-lite/zero_copy_stream_impl_lite.cc @@ -0,0 +1 @@ +../protobuf/src/google/protobuf/io/zero_copy_stream_impl_lite.cc \ No newline at end of file diff --git a/enclave/queue/queue.h b/enclave/queue/queue.h new file mode 100644 index 0000000..e526e4a --- /dev/null +++ b/enclave/queue/queue.h @@ -0,0 +1,47 @@ +// Copyright 2023 Signal Messenger, LLC +// 
SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_QUEUE_QUEUE_H__ +#define __SVR2_QUEUE_QUEUE_H__ + +#include +#include +#include +#include "util/macros.h" + +namespace svr2::queue { + +template +class Queue { + public: + Queue(size_t max_size) : max_size_(max_size) {} + + void Push(T val) { + std::unique_lock lock(mu_); + notfull_.wait(lock, [this]{ return d_.size() < max_size_; }); + d_.emplace_back(std::move(val)); + lock.unlock(); + full_.notify_one(); + } + + T Pop() { + std::unique_lock lock(mu_); + full_.wait(lock, [this]{ return d_.size() > 0; }); + T out = std::move(d_.front()); + d_.pop_front(); + lock.unlock(); + notfull_.notify_one(); + return out; + } + + private: + std::mutex mu_; + std::condition_variable full_; + std::condition_variable notfull_; + std::deque d_; + size_t max_size_; +}; + +} // namespace svr2::queue + +#endif // __SVR2_QUEUE_QUEUE_H__ diff --git a/enclave/queue/tests/queue.cc b/enclave/queue/tests/queue.cc new file mode 100644 index 0000000..8a385b0 --- /dev/null +++ b/enclave/queue/tests/queue.cc @@ -0,0 +1,44 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +#include +#include +#include +#include "queue/queue.h" +#include + +namespace svr2::queue { + +class QueueTest : public ::testing::Test {}; + +void QueueReadThread(Queue* q, int n) { + int sum = 0; + for (int i = 0; i < n; i++) { + sum += q->Pop(); + } + ASSERT_EQ(sum, n); +} + +void QueueWriteThread(Queue* q, int n) { + for (int i = 0; i < n; i++) { + q->Push(1); + } +} + +TEST_F(QueueTest, BasicUsage) { + std::vector threads; + Queue q(16); + for (int i = 0; i < 10; i++) { + threads.emplace_back(QueueReadThread, &q, 1000); + } + sleep(1); + for (int i = 0; i < 5; i++) { + threads.emplace_back(QueueWriteThread, &q, 2000); + } + for (int i = 0; i < threads.size(); i++) { + threads[i].join(); + } +} + +} // namespace svr2::queue diff --git a/enclave/raft/internal.h b/enclave/raft/internal.h new file mode 100644 
index 0000000..6eb3c9b --- /dev/null +++ b/enclave/raft/internal.h @@ -0,0 +1,84 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_RAFT_INTERNAL_H__ +#define __SVR2_RAFT_INTERNAL_H__ + +#include +#include +#include +#include "peerid/peerid.h" +#include "raft/types.h" +#include "raft/log.h" +#include "raft/membership.h" +#include "util/ticks.h" + +namespace svr2::raft { + +class Raft; // forward declaration + +namespace internal { + +struct ReplicationState { + // \* The next entry to send to each follower. + // VARIABLE nextIndex + LogIdx next_idx; + + // \* The latest entry that each follower has acknowledged is the same as the + // \* leader's. This is used to calculate promiseIndex on the leader. + // VARIABLE matchIndex + LogIdx match_idx; + // The latest entry that each follower has promised. This is used + // to calculate commitIndex on the leader. + LogIdx promise_idx; + + // inflight - this field is very interesting, and is not part of the generic + // Raft protocol. As long as this is set to some LogIdx, we won't send + // additional AppendRequests to this replica. In generic Raft, this would + // not work at all, as a single dropped message would break our ability to + // ever append to its destination replica. However, given that our host-side + // message passing is in-order and lossless (the host will store-and-forward + // our messages, never dropping them, until a message has been received and + // acknowledged), this saves us sending duplicate logs over the network. + // A crucial concern here, though, is that if for some reason a message is + // dropped and we notice it, we must clear this value so that our next + // AppendEntries will go through. + std::optional inflight; + + // send_probe requests that the next AppendEntries request to this peer + // not contain any actual entries, just the log index we think they're + // at. This allows them to correct us without over-sending logs. 
+ bool send_probe; + bool send_heartbeat; + + // the number of ticks since we last got a Raft message from this replica. + util::Ticks last_seen_ticks; +}; + +enum class Role { + FOLLOWER = 1, + CANDIDATE = 2, + LEADER = 3, +}; + +struct FollowerState { + std::optional leader; + util::Ticks election; +}; +struct CandidateState { + // \* The set of replicas from which this candidate has received a granted + // \* vote in its current term (presumably keyed by peer ID). + // VARIABLE votesGranted + std::set votes_granted; + util::Ticks election; +}; +struct LeaderState { + std::map followers; + util::Ticks heartbeat; + bool relinquishing; // if true, this leader is trying to become a follower +}; + +} // namespace internal +} // namespace svr2::raft + +#endif // __SVR2_RAFT_INTERNAL_H__ diff --git a/enclave/raft/log.cc b/enclave/raft/log.cc new file mode 100644 index 0000000..31e9018 --- /dev/null +++ b/enclave/raft/log.cc @@ -0,0 +1,178 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "raft/log.h" +#include "peerid/peerid.h" +#include +#include "util/log.h" +#include "metrics/metrics.h" + +namespace svr2::raft { + +// Guess the size of the logentry in bytes in memory, including the +// container (map node and key) holding it. +size_t Log::logentry_bytes_in_log(const LogEntry& e) { + // Estimate size of membership change proto.
+ size_t mem_change = 0; + if (e.has_membership_change()) { + mem_change += + sizeof(ReplicaGroup) + + // for each replica: + e.membership_change().replicas_size() * ( + // each replica should point to a string of this size + sizeof(peerid::PeerID) + + // plus the size of the replica object itself + sizeof(Replica)); + } + if (e.hash_chain().size()) { + mem_change += e.hash_chain().size() + + sizeof(std::string); + } + return mem_change + + // Size of the value + sizeof(LogEntry) + + // Size of the data on the heap, which should be at least a length and the actual bytes + sizeof(std::string) + e.data().size(); +} + +Log::Log(size_t max_bytes) : oldest_stored_idx_(1), curr_bytes_(0), max_bytes_(max_bytes) { + GAUGE(raft, log_total_size)->Set(max_bytes_); + UpdateMetrics(); +} + +error::Error Log::CancelFrom(LogIdx from_log_idx) { + if (from_log_idx < oldest_stored_idx_) { + return COUNTED_ERROR(Raft_CancelingBeforeFirst); + } + size_t index = from_log_idx - oldest_stored_idx_; + entries_.resize(index); + return error::OK; +} + +Log::Iterator Log::At(LogIdx idx) const { + size_t di = idx < oldest_stored_idx_ ? entries_.size() : idx - oldest_stored_idx_; + return Iterator(this, di); +} + +LogIdx Log::Iterator::Index() const { + if (!Valid()) return 0; + return log_->oldest_stored_idx_ + deque_index_; +} + +bool Log::Iterator::Valid() const { + return deque_index_ < log_->entries_.size(); +} + +const LogEntry* Log::Iterator::Entry() const { + if (!Valid()) return nullptr; + return &log_->entries_[deque_index_]; +} + +TermId Log::Iterator::Term() const { + if (!Valid()) return 0; + return Entry()->term(); +} + +size_t Log::Iterator::SerializedSize() const { + if (!Valid()) return 0; + // We called ByteSizeLong when we appended this log entry, and its + // size can't have changed since, so GetCachedSize will give us the + // correct value. Guaranteed to be >= 1 since we check that term() + // is nonzero. 
Must return a value <= INT_MAX, since GetCachedSize + // returns an int. + return (size_t) Entry()->GetCachedSize(); +} + +size_t Log::Iterator::MemSize() const { + if (!Valid()) return 0; + return logentry_bytes_in_log(*Entry()); +} + +void Log::Iterator::Next() { + if (Valid()) { ++deque_index_; } +} + +void Log::Iterator::Prev() { + if (Valid()) { --deque_index_; } +} + +LogIdx Log::oldest_stored_idx() const { + if (entries_.size() == 0) { return 0; } + return oldest_stored_idx_; +} + +LogIdx Log::last_idx() const { + if (entries_.size() == 0) { return 0; } + return oldest_stored_idx_ + entries_.size() - 1; +} + +LogIdx Log::last_term() const { + if (entries_.size() == 0) { return 0; } + return At(last_idx()).Term(); +} + +error::Error Log::Append(const LogEntry& log, LogIdx maybe_truncate_to) { + if (log.term() == 0) { + return COUNTED_ERROR(Raft_AppendWithoutTerm); + } + if (log.hash_chain().size() != 32) { + return COUNTED_ERROR(Raft_NoHashChainInAppend); + } + size_t mem = logentry_bytes_in_log(log); + while (curr_bytes_ + mem > max_bytes_) { + if (!RemoveOldestLogOlderThan(maybe_truncate_to)) { + return COUNTED_ERROR(Raft_LogOutOfSpace); + } + } + // Don't allow larger than 2G, since that'll mess up our call to GetCachedSize + // which returns an int. + if (log.ByteSizeLong() > INT_MAX) { + return COUNTED_ERROR(Raft_LogEntryTooLarge); + } + // This creates a copy of the log, which is important since the + // original log we got the reference from may fall out of scope before + // we do. + entries_.emplace_back(log); + // Re-compute byte size, in the new location. 
+ entries_.rbegin()->ByteSizeLong(); + curr_bytes_ += mem; + UpdateMetrics(); + return error::OK; +} + +bool Log::RemoveOldestLogOlderThan(LogIdx truncate_to) { + if (oldest_stored_idx_ >= truncate_to) return false; + size_t mem = logentry_bytes_in_log(entries_.front()); + entries_.pop_front(); + oldest_stored_idx_++; + curr_bytes_ -= mem; + return true; +} + +void Log::UpdateMetrics() { + GAUGE(raft, log_oldest_stored_log_index)->Set(oldest_stored_idx()); + GAUGE(raft, log_last_log_index)->Set(last_idx()); + GAUGE(raft, log_last_log_term)->Set(last_term()); + GAUGE(raft, log_size)->Set(curr_bytes_); + GAUGE(raft, log_entries)->Set(entries_.size()); +} + +bool Log::MostRecentHash(std::array* out) { + for (auto iter = At(last_idx()); iter.Valid(); iter.Prev()) { + if (iter.Entry()->hash_chain().size() == out->size()) { + std::copy(iter.Entry()->hash_chain().cbegin(), iter.Entry()->hash_chain().cend(), out->begin()); + return true; + } + } + return false; +} + +error::Error Log::SetNextIdx(LogIdx idx) { + if (entries_.size()) { + return COUNTED_ERROR(Raft_SetNextOnNonemptyLog); + } + oldest_stored_idx_ = idx; + return error::OK; +} + +} // namespace svr2::raft diff --git a/enclave/raft/log.h b/enclave/raft/log.h new file mode 100644 index 0000000..5209f8a --- /dev/null +++ b/enclave/raft/log.h @@ -0,0 +1,112 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_RAFT_LOG_H__ +#define __SVR2_RAFT_LOG_H__ + +#include +#include +#include "raft/types.h" +#include "proto/error.pb.h" +#include "proto/raft.pb.h" +#include "util/macros.h" + +namespace svr2::raft { + +// Raft log storage. +class Log { + public: + Log(size_t max_bytes); + + class Iterator { + public: + // Returns true if this iterator points to a valid entry. If !Valid, + // other functions will all return default zero-values. + bool Valid() const; + // Which index we point to. + LogIdx Index() const; + // Return the log entry at this index, nullptr if !Valid. 
+ const LogEntry* Entry() const; + // Return the term ID at this index. + TermId Term() const; + // Estimated serialized size of the LogEntry proto. + size_t SerializedSize() const; + // Estimated in-memory size of this log entry. + size_t MemSize() const; + // Move the iterator forward, may invalidate if we're at the end of the log. + // Typical usage: + // + // for (auto iter = log->At(123); iter.Valid(); iter.Next()) { ... } + void Next(); + // Move the iterator backwards, may invalidate if we're at the beginning of the log. + // Typical usage: + // + // for (auto iter = log->At(123); iter.Valid(); iter.Prev()) { ... } + void Prev(); + private: + friend class Log; + Iterator(const Log* log, size_t di) : log_(log), deque_index_(di) {} + const Log* log_; + size_t deque_index_; + }; + // Returns a new iterator. Any change to the log (Append, RemoveOldestLogOlderThan, + // CancelFrom) may invalidate this iterator. + Iterator At(LogIdx idx) const; + + // oldest_stored_idx returns the index of the least recent entry this log stores. + // It's incremented by a successful call to RemoveOldestLogOlderThan. + LogIdx oldest_stored_idx() const; + // last_idx returns the index of the most recent entry this log stores. + // It's incremented by a successful call to Append. + LogIdx last_idx() const; + // next_idx returns the index that a successfully Append'd entry will have. + LogIdx next_idx() const { return last_idx() + 1; } + TermId last_term() const; + size_t log_data_length_bytes() const { return curr_bytes_; } + + // Append a log to the Log. May return one of the following: + // - Raft_LogOutOfSpace: The log is currently full + // - Raft_LogEntryTooLarge: Rejecting the log entry because it's HUGE + // - various other errors? + // While appending this log, we're allowed to truncate old logs up to (but + // not including) `maybe_truncate_to` if we run out of space. We'll only + // return Raft_LogOutOfSpace if this fails to make enough space for the new + // log entry. 
+ error::Error Append(const LogEntry& log, LogIdx maybe_truncate_to); + + // CancelFrom cancels (removes) all logs from the given log index on, + // leaving only entries of [start,from_log_idx) remaining in the log. + // This is necessary in cases where an old leader's uncommitted logs are + // overridden by a new leader. + error::Error CancelFrom(LogIdx from_log_idx); + + // Get the most recent hash chain value from the log. + bool MostRecentHash(std::array* out); + + // If this log is empty, set what the next index will be. This is useful + // in cases where we're replicating an already-truncated log. + error::Error SetNextIdx(LogIdx idx); + + // Return true if there are no log entries in this log. + bool empty() const { return entries_.size() == 0; } + + public_for_test: + static size_t logentry_bytes_in_log(const LogEntry& e); + + private: + // RemoveOldestLogOlderThan removes the oldest log from the Log and returns + // true. It will return false if there is no log older than [truncate_to]. 
+ bool RemoveOldestLogOlderThan(LogIdx truncate_to); + + friend class Iterator; + void UpdateMetrics(); + + std::deque entries_; + LogIdx oldest_stored_idx_; + size_t curr_bytes_; + size_t max_bytes_; +}; + +} // namespace svr2::raft + +#endif // __SVR2_RAFT_LOG_H__ diff --git a/enclave/raft/membership.cc b/enclave/raft/membership.cc new file mode 100644 index 0000000..371a971 --- /dev/null +++ b/enclave/raft/membership.cc @@ -0,0 +1,118 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "raft/membership.h" +#include "util/log.h" +#include "metrics/metrics.h" +#include + +namespace svr2::raft { + +std::pair, error::Error> Membership::FromProto(const ReplicaGroup& group) { + std::unique_ptr out(new Membership()); + + for (const auto& replica : group.replicas()) { + peerid::PeerID p; + error::Error peer_err = p.FromString(replica.peer_id()); + if (peer_err != error::OK) { + return std::make_pair(nullptr, peer_err); + } + if (out->all_replicas_.count(p)) { + return std::make_pair(nullptr, COUNTED_ERROR(Membership_DuplicateReplicaInReplicaGroup)); + } + out->all_replicas_.insert(p); + if (replica.voting()) { out->voting_replicas_.insert(p); } + } + return std::make_pair(std::move(out), error::OK); +} + +std::unique_ptr Membership::First(const peerid::PeerID& me) { + std::unique_ptr out(new Membership()); + out->voting_replicas_.insert(me); + out->all_replicas_.insert(me); + return out; +} + +// Returns the size of the set [a-b], IE: set a with all elements of set b removed from it. 
+size_t SetDiffSize(const std::set& a, const std::set& b) { + auto a_iter = a.cbegin(); + auto b_iter = b.cbegin(); + size_t out = 0; + while (a_iter != a.cend()) { + if (b_iter == b.cend()) { + ++out; + ++a_iter; + } else if (*a_iter < *b_iter) { + ++a_iter; + ++out; + } else if (*b_iter < *a_iter) { + ++b_iter; + } else { // *a_iter == *b_iter + ++a_iter; + ++b_iter; + } + } + return out; +} + +error::Error Membership::ValidProgressionForLeader( + const peerid::PeerID& leader, + const Membership& from, + const Membership& to, + size_t super_majority) { + if (from.voting_replicas_.size() > super_majority && to.voting_replicas_.size() <= super_majority) { + return COUNTED_ERROR(Membership_SuperMajorityLost); + } + size_t voting_additions = SetDiffSize(to.voting_replicas_, from.voting_replicas_); + std::vector voting_removals; + std::set_difference( + from.voting_replicas_.begin(), from.voting_replicas_.end(), + to.voting_replicas_.begin(), to.voting_replicas_.end(), + std::back_inserter(voting_removals)); + size_t all_additions = SetDiffSize(to.all_replicas_, from.all_replicas_); + std::vector all_removals; + std::set_difference( + from.all_replicas_.begin(), from.all_replicas_.end(), + to.all_replicas_.begin(), to.all_replicas_.end(), + std::back_inserter(all_removals)); + size_t all_changes = voting_additions + voting_removals.size() + all_additions + all_removals.size(); + if (to.voting_replicas_.size() == 0 || to.all_replicas_.size() == 0) { + return COUNTED_ERROR(Membership_EmptySet); + } + if (all_changes == 2 && voting_removals.size() == 1 && all_removals.size() == 1 && voting_removals[0] == all_removals[0]) { + // We allow there to be exactly two changes in the case where they are: + // * remove peer X from voting replicas + // * remove the same peer X from all replicas + // We allow this so that, on shutdown, a replica can request to be fully + // removed from the Raft group in a single step. 
+ } else if (all_changes > 1) { + return COUNTED_ERROR(Membership_TooManyMembershipChanges); + } + if (all_changes == 0) { + return COUNTED_ERROR(Membership_NoMembershipChanges); + } + if (!to.voting_replicas_.count(leader)) { + return COUNTED_ERROR(Membership_LeaderRemovedFromVoting); + } + if (!to.all_replicas_.count(leader)) { + return COUNTED_ERROR(Membership_LeaderRemovedFromAll); + } + if (SetDiffSize(to.all_replicas_, to.voting_replicas_) != to.all_replicas_.size() - to.voting_replicas_.size()) { + return COUNTED_ERROR(Membership_VotingNotSubset); + } + return error::OK; +} + +ReplicaGroup Membership::AsProto() const { + ReplicaGroup g; + for (auto peer : all_replicas_) { + auto r = g.add_replicas(); + peer.ToString(r->mutable_peer_id()); + if (voting_replicas_.count(peer)) { + r->set_voting(true); + } + } + return g; +} + +} // namespace svr2::raft diff --git a/enclave/raft/membership.h b/enclave/raft/membership.h new file mode 100644 index 0000000..6fd5863 --- /dev/null +++ b/enclave/raft/membership.h @@ -0,0 +1,52 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_RAFT_MEMBERSHIP_H__ +#define __SVR2_RAFT_MEMBERSHIP_H__ + +#include +#include "peerid/peerid.h" +#include "proto/error.pb.h" +#include "proto/raft.pb.h" + +namespace svr2::raft { + +size_t SetDiffSize(const std::set& a, const std::set& b); + +class Membership { + public: + DELETE_ASSIGN(Membership); + // First returns a membership from a proto, considering this to be + // the first membership of Raft. + static std::unique_ptr First(const peerid::PeerID& me); + // FromProto does minimal error checking and returns the membership as + // ReplicaGroup describes it. 
+ static std::pair, error::Error> FromProto(const ReplicaGroup& group); + + const std::set& all_replicas() const { return all_replicas_; } + const std::set& voting_replicas() const { return voting_replicas_; } + + // ValidProgressionForLeader checks if a change in membership from [from] to + // [to] should be accepted by raft leader [leader]. If so, returns error::OK. + // If not, returns an error explaining the issue. + static error::Error ValidProgressionForLeader( + const peerid::PeerID& leader, + const Membership& from, + const Membership& to, + size_t super_majority); + + ReplicaGroup AsProto() const; + + public_for_test: + Membership(const Membership& other) = default; // allow copy + private: + Membership() = default; + // all_replicas includes all peers, including me. + std::set all_replicas_; + // voting_replicas includes all replicas that can vote. + std::set voting_replicas_; +}; + +} // namespace svr2::raft + +#endif // __SVR2_RAFT_MEMBERSHIP_H__ diff --git a/enclave/raft/raft.cc b/enclave/raft/raft.cc new file mode 100644 index 0000000..370b06a --- /dev/null +++ b/enclave/raft/raft.cc @@ -0,0 +1,1147 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "raft/raft.h" +#include +#include "util/log.h" +#include +#include "metrics/metrics.h" +#include "hmac/hmac.h" +#include "util/constant.h" +#include "util/bytes.h" + +#define MELOG(x) LOG(x) << "(" << me().DebugString() << ") " + +namespace svr2::raft { + +Raft::Raft( + GroupId group, + const peerid::PeerID& me, + std::unique_ptr mem, + std::unique_ptr log, + const enclaveconfig::RaftConfig& config, + bool committed_log, + size_t super_majority) + : group_(group), + me_(me), + membership_(std::move(mem)), + config_(config), + last_applied_(committed_log ? log->last_idx() : 0), + current_term_(0), + log_(std::move(log)), + commit_idx_(committed_log ? log_->last_idx() : 0), + promise_idx_(committed_log ? 
log_->last_idx() : 0), + super_majority_(super_majority) { + SetRole(internal::Role::FOLLOWER); + follower_.election = RandomElectionTimeout(); + GAUGE(raft, commit_index)->Set(commit_idx_); + GAUGE(raft, promise_index)->Set(promise_idx_); + if (voting() && membership().voting_replicas().size() == 1) { + // This is a one-instance replica and I'm voting, become leader. + context::Context ctx; + ElectionTimeout(&ctx); + MaybeChangeStateAndSendMessages(&ctx); + CHECK(sendable_messages_.size() == 0); + } +} + +size_t Raft::membership_quorum_size() const { + return quorum_size(membership().voting_replicas().size(), super_majority_); +} + +size_t Raft::quorum_size(size_t voting_replicas, size_t super_majority) { + return std::min( + voting_replicas, + (voting_replicas - super_majority) / 2 + 1 + super_majority); +} + +static const char* RoleName(internal::Role r) { + switch (r) { + case internal::Role::LEADER: return "LEADER"; + case internal::Role::CANDIDATE: return "CANDIDATE"; + case internal::Role::FOLLOWER: return "FOLLOWER"; + } + return "UNKNOWN_ROLE"; +} + +void Raft::SetRole(internal::Role r) { + MELOG(INFO) << "Raft switching to role " << RoleName(r) << " at term " << current_term_; + role_ = r; + leader_ = {}; + follower_ = {}; + candidate_ = {}; + GAUGE(raft, role)->Set(static_cast(r)); +} + +std::pair Raft::TakeCommittedLog() { + if (last_applied_ >= commit_idx_) { + return std::make_pair(0, LogEntry()); + } + last_applied_++; + // If it's committed, we should have it. 
+ auto iter = log_->At(last_applied_); + CHECK(iter.Valid()); + LogEntry out(*iter.Entry()); + return std::make_pair(last_applied_, std::move(out)); +} + +std::optional Raft::leader() const { + switch (role_) { + case internal::Role::FOLLOWER: return follower_.leader; + case internal::Role::CANDIDATE: return std::optional(); + case internal::Role::LEADER: return std::optional(me_); + default: CHECK(nullptr == "Raft state without valid role"); + } +} +void Raft::set_election_timeout(util::Ticks t) { + config_.set_election_ticks(t); + follower_.election = std::min(t, follower_.election); + candidate_.election = std::min(t, candidate_.election); +} +void Raft::set_heartbeat_timeout(util::Ticks t) { + config_.set_heartbeat_ticks(t); + leader_.heartbeat = t; +} +void Raft::TimerTick(context::Context* ctx) { + switch (role_) { + case internal::Role::FOLLOWER: + if (0 >= --follower_.election) { + LOG(INFO) << "follower election timeout"; + ElectionTimeout(ctx); + } + break; + case internal::Role::CANDIDATE: + if (0 >= --candidate_.election) { + LOG(INFO) << "candidate election timeout"; + ElectionTimeout(ctx); + } + break; + case internal::Role::LEADER: + for (auto i = leader_.followers.begin(); i != leader_.followers.end(); ++i) { + if (i->second.last_seen_ticks != util::InvalidTicks) { + i->second.last_seen_ticks++; + } + } + if (0 >= --leader_.heartbeat) { + LOG(VERBOSE) << "leader sending heartbeat"; + for (auto i = leader_.followers.begin(); i != leader_.followers.end(); ++i) { + i->second.send_heartbeat = true; + } + leader_.heartbeat = config_.heartbeat_ticks(); + } + break; + } + MaybeChangeStateAndSendMessages(ctx); +} + +void Raft::ResetPeer(context::Context* ctx, const peerid::PeerID& id) { + if (!membership().all_replicas().count(id)) { + // Don't bother doing anything if this isn't one of our Raft peers. 
+ return; + } + switch (role_) { + case internal::Role::FOLLOWER: + return; + case internal::Role::CANDIDATE: + // Since this peer may have lost messages, it may have lost our + // request for a vote, so resend it. + AddSendableMessage(SendableRaftMessage::Reply(id, RequestVoteMessage(ctx))); + return; + case internal::Role::LEADER: { + auto finder = leader_.followers.find(id); + if (finder == leader_.followers.end()) { return; } + internal::ReplicationState& state = finder->second; + state.next_idx = log_->last_idx() + 1; + state.send_probe = true; + state.send_heartbeat = true; + state.inflight.reset(); + // We don't reset last_seen_ticks yet, because we haven't gotten + // a RAFT message from them. But the above means that we will + // send them a message, so we should reset it soon when we get our reply. + } return; + } +} +void Raft::Reconfigure(const enclaveconfig::RaftConfig& config) { + MELOG(INFO) << "reconfiguring raft"; + config_ = config; + set_election_timeout(config_.election_ticks()); + set_heartbeat_timeout(config.heartbeat_ticks()); +} + +void Raft::RelinquishLeadership(context::Context* ctx) { + if (role_ != internal::Role::LEADER || leader_.relinquishing) { return; } + + // Append a noop to the end of the log. Since we then wait for the first + // replica that reaches the end of our log, this makes sure that we find a replica + // that is up and running at the time of this call. Otherwise, it's possible that + // we could have a quiescent Raft group and the replica we choose may no longer + // be responding. 
+ ClientRequestInternal(ctx->Protobuf()); + + leader_.relinquishing = true; + MaybeChangeStateAndSendMessages(ctx); +} + +std::pair Raft::LogAppend(const LogEntry& entry) { + error::Error err = error::OK; + std::unique_ptr new_uncommitted_membership; + if (entry.has_membership_change()) { + auto [mem, err] = Membership::FromProto(entry.membership_change()); + if (err != error::OK) { + COUNTER(raft, logs_append_failure)->Increment(); + LOG(ERROR) << "failing to append invalid membership change in Raft uncommitted log at idx=" + << log_->next_idx() << ", error=" << err; + return std::make_pair(LogLocation(), err); + } + new_uncommitted_membership = std::move(mem); + } + if (error::OK != (err = log_->Append(entry, last_applied_))) { + // Some unhandleable Raft error occurred. + COUNTER(raft, logs_append_failure)->Increment(); + return std::make_pair(LogLocation(), err); + } + LogLocation loc; + loc.set_term(current_term_); + loc.set_idx(log_->last_idx()); + loc.set_hash_chain(entry.hash_chain()); + if (new_uncommitted_membership.get() != nullptr) { + AddUncommittedMembership(loc.idx(), std::move(new_uncommitted_membership)); + } + COUNTER(raft, logs_append_success)->Increment(); + return std::make_pair(loc, error::OK); +} + +// +// -- raft TLA+ parallel code -- +// the code below is so similar to Raft's TLA+ code that the TLA+ is provided +// in the right-hand column for sections which correspond almost exactly. code +// is provided in the same order as the TLA+ so that the reader can follow. +// + +// +// \* Define state transitions +// + +// \* Server i times out and starts a new election. +void Raft::ElectionTimeout(context::Context* ctx) { + if (!voting()) { + LOG(WARNING) << "not a voting member, skipping election request"; + // If we're a non-voting follower, reset our election ticks. 
+ follower_.election = RandomElectionTimeout(); + return; + } + COUNTER(raft, election_timeouts)->Increment(); + switch (role_) { + case internal::Role::CANDIDATE: + case internal::Role::FOLLOWER: { + // /\ state[i] \in {Follower, Candidate} + // /\ currentTerm' = [currentTerm EXCEPT ![i] = currentTerm[i] + 1] + // \* Most implementations would probably just set the local vote + // \* atomically, but messaging localhost for it is weaker. + current_term_++; + GAUGE(raft, current_term)->Set(current_term_); + // /\ votedFor' = [votedFor EXCEPT ![i] = Nil] + voted_for_ = me_; + // /\ votesGranted' = [votesGranted EXCEPT ![i] = {}] + std::set votes_granted; + votes_granted.insert(me_); + + // /\ state' = [state EXCEPT ![i] = Candidate] + SetRole(internal::Role::CANDIDATE); + candidate_ = { + .votes_granted = std::move(votes_granted), + .election = RandomElectionTimeout(), + }; + + MELOG(INFO) << "became candidate at term " << current_term_; + AddSendableMessage(SendableRaftMessage::Broadcast(RequestVoteMessage(ctx))); + break; + } + default: + break; + } +} + +// \* Candidate i sends j a RequestVote request. +RaftMessage* Raft::RequestVoteMessage(context::Context* ctx) { + // RequestVote(i,j) == + // /\ state[i] = Candidate + CHECK(role_ == internal::Role::CANDIDATE); + // /\ Send([ + auto msg = ctx->Protobuf(); + msg->set_group(group_); + // mterm |-> currentTerm[i], + msg->set_term(current_term_); + // mtype |-> RequestVoteRequest, + auto vote_req = msg->mutable_vote_request(); + // mlastLogTerm |-> LastTerm(log[i]), + vote_req->set_last_log_term(log_->last_term()); + // mlastLogIndex |-> Len(log[i]), + vote_req->set_last_log_idx(log_->last_idx()); + return msg; +} + +// \* Leader i sends j an AppendEntries request containing up to 1 entry. +// \* While implementations may want to send more than 1 at a time, this spec uses +// \* just 1 because it minimizes atomic regions without loss of generality. 
+void Raft::AppendEntries(context::Context* ctx, const peerid::PeerID& peer) { + // AppendEntries(i, j) == + // /\ state[i] = Leader + if (role_ != internal::Role::LEADER) { return; } + // /\ i /= j + if (0 == leader_.followers.count(peer)) { return; } + internal::ReplicationState& replication = leader_.followers[peer]; + uint64_t last_log_idx = log_->last_idx(); + uint64_t next_idx = replication.next_idx; + bool send_entries = last_log_idx >= next_idx && !replication.send_probe; + if (!send_entries && !replication.send_heartbeat && !replication.send_probe) { return; } + if (replication.inflight.has_value()) { return; } + MELOG(VERBOSE) << "sending appendentries to " << peer; + + // /\ LET prevLogIndex == nextIndex[i][j] - 1 + LogIdx prev_log_idx = next_idx - 1; + // prevLogTerm == IF prevLogIndex > 0 THEN + // log[i][prevLogIndex].term + // ELSE + // 0 + uint64_t prev_log_term = prev_log_idx == 0 ? 0 : log_->At(prev_log_idx).Term(); + if (prev_log_term == 0 && prev_log_idx != 0) { + LOG(ERROR) << "missing log " << prev_log_idx << " to send to " << peer; + return; + } + std::vector entries; + LogIdx last_entry = prev_log_idx; + // \* Send up to 1 entry, constrained by the end of the log. 
+ if (send_entries) { + size_t max_entries_size = config_.replication_chunk_bytes(); + uint64_t start_entry = next_idx; + uint64_t limit_entry = std::min(last_log_idx + 1, start_entry+max_entries_size); + // entries == SubSeq(log[i], nextIndex[i][j], lastEntry) + for (uint64_t entry_idx = start_entry; entry_idx < limit_entry; entry_idx++) { + const LogEntry* e = log_->At(entry_idx).Entry(); + if (e == nullptr) { + LOG(ERROR) << "error fetching raft log " << entry_idx << " to send to " << peer; + break; + } + entries.emplace_back(*e); + } + // lastEntry == Min({Len(log[i]), nextIndex[i][j]}) + last_entry = prev_log_idx + entries.size(); + } + + // IN Send([ + auto msg = ctx->Protobuf(); + msg->set_group(group_); + // mterm |-> currentTerm[i], + msg->set_term(current_term_); + // mtype |-> AppendEntriesRequest, + auto append = msg->mutable_append_request(); + // mprevLogIndex |-> prevLogIndex, + append->set_prev_log_idx(prev_log_idx); + // mprevLogTerm |-> prevLogTerm, + append->set_prev_log_term(prev_log_term); + // mentries |-> entries, + for (size_t i = 0; i < entries.size(); i++) { + *append->add_entries() = std::move(entries[i]); + } + // mcommitIndex |-> Min({commitIndex[i], lastEntry}), + append->set_leader_commit(std::min(commit_idx_, last_entry)); + append->set_leader_promise(std::min(promise_idx_, last_entry)); + + replication.send_heartbeat = false; + replication.inflight = last_entry; + AddSendableMessage(SendableRaftMessage::Reply(peer, msg)); +} + +void Raft::MaybeBecomeLeader(context::Context* ctx) { + // BecomeLeader(i) == + // /\ state[i] = Candidate + if (role_ != internal::Role::CANDIDATE) { return; } + // /\ votesGranted[i] \in Quorum + if (candidate_.votes_granted.size() < membership_quorum_size()) { return; } + LOG(INFO) << "becoming leader at " << current_term_; + SetRole(internal::Role::LEADER); + leader_ = { + .heartbeat = 0, + }; + for (auto peer : membership().all_replicas()) { + if (peer == me_) continue; + leader_.followers[peer] = { + // 
/\ nextIndex' = [nextIndex EXCEPT ![i] = [j \in Server |-> Len(log[i]) + 1]] + .next_idx = log_->next_idx(), + // /\ matchIndex' = [matchIndex EXCEPT ![i] = [j \in Server |-> 0]] + }; + } + // append a noop in the new term to commit entries from past terms (Raft Section 5.4.2) + ClientRequestInternal(ctx->Protobuf()); +} + +void Raft::AddUncommittedMembership( + LogIdx idx, std::unique_ptr uncommitted_membership) { + // Uncommitted memberships should always be stored in log index order. + CHECK(uncommitted_memberships_.size() == 0 + || uncommitted_memberships_.back().first < idx); + uncommitted_memberships_.emplace_back(idx, std::move(uncommitted_membership)); + HandleMembershipChange(); +} + +void Raft::HandleMembershipChange() { + if (role_ == internal::Role::LEADER) { + // If there's any new followers (voting or not) in the new uncommitted + // membership, add them to the current leader's [followers] map. + for (auto peer : membership().all_replicas()) { + if (peer != me_ && leader_.followers.count(peer) == 0) { + // Same as in MaybeBecomeLeader: + leader_.followers[peer] = { + .next_idx = log_->next_idx(), + .send_probe = true, + .send_heartbeat = true, + // We set this to a number high enough that we won't immediately add + // this replica to the set of voting replicas, and low enough that we + // won't immediately kick them for being unresponsive. + .last_seen_ticks = config_.election_ticks(), + }; + } + } + // We probably don't need to remove followers from this leader, but + // it keeps our followers==all_replicas story intact, so it seems + // safer to do it. + for (auto iter = leader_.followers.begin(); iter != leader_.followers.end(); ) { + if (membership().all_replicas().count(iter->first) == 0) { + iter = leader_.followers.erase(iter); + } else { + ++iter; + } + } + } + LOG(INFO) << "Membership change"; + for (auto peer : membership().all_replicas()) { + LOG(INFO) << "* " << peer << (membership().voting_replicas().count(peer) ? 
" (voting)" : ""); + } +} + +std::array Raft::NextHash(const LogEntry& next_entry) { + std::array previous_hash = {0}; + log_->MostRecentHash(&previous_hash); + // We add prefixes to each input, so that inputs with the same serialization are distinct. + switch (next_entry.inner_case()) { + case LogEntry::kData: + return hmac::HmacSha256(previous_hash, "\001" + next_entry.data()); + case LogEntry::kMembershipChange: { + std::string serialized = next_entry.membership_change().SerializeAsString(); + return hmac::HmacSha256(previous_hash, "\002" + serialized); + } + case LogEntry::INNER_NOT_SET: + return hmac::HmacSha256(previous_hash, "\003"); + } +} + +// \* Leader i receives a client request to add v to the log. +std::pair Raft::ClientRequestInternal(LogEntry* entry) { + // NON-TLA+: Set up hash chain for entry: + auto new_hash = NextHash(*entry); + entry->set_hash_chain(util::ByteArrayToString(new_hash)); + // ClientRequest(i, v) == + // /\ LET entry == [term |-> currentTerm[i], + entry->set_term(current_term_); + // value |-> v] + // /\ state[i] = Leader + if (role_ != internal::Role::LEADER || leader_.relinquishing) { + return std::make_pair(LogLocation(), COUNTED_ERROR(Raft_AppendEntryNotLeader)); + } + // newLog == Append(log[i], entry) + return LogAppend(*entry); + // IN log' = [log EXCEPT ![i] = newLog] +} + +std::pair Raft::ClientRequest(context::Context* ctx, const std::string& data) { + auto entry = ctx->Protobuf(); + *entry->mutable_data() = data; + auto out = ClientRequestInternal(entry); + MaybeChangeStateAndSendMessages(ctx); + return out; +} + +std::pair Raft::ReplicaGroupChange(context::Context* ctx, const ReplicaGroup& g) { + // We will check role again in ClientRequestInternal, but we + // do some checks here that assume leadership, so check here + // before we do those. 
+ if (role_ != internal::Role::LEADER || leader_.relinquishing) { + MELOG(VERBOSE) << "received ReplicaGroupRequest but not leader"; + return std::make_pair(LogLocation(), COUNTED_ERROR(Raft_AppendEntryNotLeader)); + } + // We allow only one uncommitted membership change within uncommitted + // logs. If we already have one, reject this request. + if (uncommitted_memberships_.size()) { + return std::make_pair(LogLocation(), COUNTED_ERROR(Raft_MembershipAlreadyChanging)); + } + // Is this change actually valid? + auto [next, err] = Membership::FromProto(g); + if (err != error::OK) { + return std::make_pair(LogLocation(), err); + } + // Does this change do anything detrimental, like remove the voting rights + // of the current leader, emptying out all voters, etc? + err = Membership::ValidProgressionForLeader(me_, *membership_, *next, super_majority_); + if (err != error::OK) { + return std::make_pair(LogLocation(), err); + } + // If we're here, we're going to attempt to move forward with this request. + auto entry = ctx->Protobuf(); + *entry->mutable_membership_change() = g; + LOG(VERBOSE) << "Requesting raft membership change"; + auto out = ClientRequestInternal(entry); + MaybeChangeStateAndSendMessages(ctx); + return out; +} + +// \* Leader i advances its commitIndex. +// \* This is done as a separate step from handling AppendEntries responses, +// \* in part to minimize atomic regions, and in part so that leaders of +// \* single-server clusters are able to mark entries committed. +void Raft::MaybeAdvanceCommitIndex() { + // AdvanceCommitIndex(i) == + // /\ state[i] = Leader + if (role_ != internal::Role::LEADER) { return; } + // /\ LET \* The set of servers that agree up through index. 
+ // Agree(index) == {i} \cup {k \in Server : matchIndex[i][k] >= index} + // \* The maximum indexes for which a quorum agrees + // agreeIndexes == {index \in 1..Len(log[i]) : Agree(index) \in Quorum} + // \* New value for commitIndex'[i] + // newCommitIndex == IF /\ agreeIndexes /= {} + // /\ log[i][Max(agreeIndexes)].term = currentTerm[i] + // THEN Max(agreeIndexes) + // ELSE commitIndex[i] + // IN commitIndex' = [commitIndex EXCEPT ![i] = newCommitIndex] + std::vector stored; + std::vector promised; + for (auto [peer, replication_state] : leader_.followers) { + if (membership().voting_replicas().count(peer)) { + stored.push_back(replication_state.match_idx); + promised.push_back(replication_state.promise_idx); + } + } + // Sort descending, so that stored[N-1] contains the highest index + // agreed upon by N replicas. + stored.push_back(log_->last_idx()); + std::sort(stored.begin(), stored.end(), [](uint64_t a, uint64_t b){ return a > b; }); + LogIdx new_promise = stored[membership_quorum_size()-1]; // -1 because zero-indexed + bool changed = false; + if (new_promise > promise_idx_) { + LOG(VERBOSE) << "promising logs " << promise_idx_ << " to " << new_promise; + COUNTER(raft, logs_promised)->IncrementBy(new_promise - promise_idx_); + promise_idx_ = new_promise; + GAUGE(raft, promise_index)->Set(promise_idx_); + changed = true; + } + // Don't push promise_idx_ until here, because we may update it above. + // This matters for size-1 raft groups. 
+ promised.push_back(promise_idx_); + std::sort(promised.begin(), promised.end(), [](uint64_t a, uint64_t b){ return a > b; }); + LogIdx new_commit = promised[membership_quorum_size()-1]; // -1 because zero-indexed + if (new_commit > commit_idx_) { + LOG(VERBOSE) << "committing logs " << commit_idx_ << " to " << new_commit; + COUNTER(raft, logs_committed)->IncrementBy(new_commit - commit_idx_); + commit_idx_ = new_commit; + GAUGE(raft, commit_index)->Set(commit_idx_); + // Committing the log has the potential to commit a previously uncomitted + // membership; check that: + MaybeChangeUncommittedMembershipsBasedOnLog(); + changed = true; + } + if (changed) { + // The following line departs slightly from the Raft protocol, erring + // on sending more remote messages in order to keep Raft followers more + // up to date with the LEADER's commit. In stock Raft, the leader is + // the only member of the replica group whose database commits "matter" + // in terms of latency. For example, in an otherwise quiescent cluster, + // if the leader gets a write, it will send that write to followers, get + // back acknowledgements, then commit it locally. But followers won't + // hear about that commit until the leader's next heartbeat, which for + // us is >= 1 tick and could be ~1s or more. Also for us, commits matter + // to followers, since we serve client requests from all replicas, + // and we serve those requests by watching the commit log. + // In practice in an active cluster, this should actually not send + // any more messages than normal, since our (also non-Raft-standard) + // `inflight` stops us from sending out an additional heartbeat to a + // follower while an existing AppendEntries is in flight, and with cluster + // activity we should expect a new log to appear at or before when we + // would clear `inflight` and actually send this heartbeat. 
But this + // makes understanding and testing out cluster activity much easier, and + // in cases where we do have lulls in traffic, it should keep client latency + // low. + // + // TLDR: when we update commits, we queue up a send_heartbeat for + // all followers in order to allow them to advance their commits without + // waiting for the next TimerTick. + for (auto iter = leader_.followers.begin(); iter != leader_.followers.end(); ++iter) { + iter->second.send_heartbeat = true; + } + } +} + +// +// \* Message handlers +// \* i = recipient, j = sender, m = message +// + +// \* Server i receives a RequestVote request from server j with +// \* m.mterm <= currentTerm[i]. +void Raft::HandleVoteRequest(context::Context* ctx, const TermId& msg_term, const VoteRequest& msg, const peerid::PeerID& from) { + // HandleRequestVoteRequest(i, j, m) == + LogIdx last_log_idx = log_->last_idx(); + TermId last_log_term = log_->last_term(); + // LET logOk == + // \/ m.mlastLogTerm > LastTerm(log[i]) + // \/ /\ m.mlastLogTerm = LastTerm(log[i]) + // /\ m.mlastLogIndex >= Len(log[i]) + bool log_ok = + msg.last_log_term() > last_log_term || ( + msg.last_log_term() == last_log_term && + msg.last_log_idx() >= last_log_idx); + // LET grant == + // /\ m.mterm = currentTerm[i] + // /\ logOk + // /\ votedFor[i] \in {Nil, j} + bool grant = + msg_term == current_term_ && + log_ok && + (!voted_for_.has_value() || *voted_for_ == from); + // IN /\ m.mterm <= currentTerm[i] + if (msg_term > current_term_) { return; } + // /\ \/ grant /\ votedFor' = [votedFor EXCEPT ![i] = j] + // \/ ~grant /\ UNCHANGED votedFor + if (grant) { + voted_for_ = from; + LOG(INFO) << "granted vote at " << current_term_ << " with " << last_log_idx << " at " << last_log_term << " for node " << from << " with " << msg.last_log_idx() << " at " << msg.last_log_term(); + // if we're a follower, reset our election ticks. 
+ follower_.election = RandomElectionTimeout(); + } else if (msg_term != current_term_) { + LOG(INFO) << "ignored vote request with " << msg_term << " < current " << current_term_; + } else if (voted_for_.has_value()) { + LOG(INFO) << "rejected vote at " << current_term_ << " for node " << from << " as already voted for " << voted_for_->DebugString(); + } else { + LOG(INFO) << "rejected vote at " << current_term_ << " with " << last_log_idx << " at " << last_log_term << " for node " << from << " at " << msg_term << " with " << msg.last_log_idx() << " at " << msg.last_log_term(); + } + // /\ Reply([ + auto resp = ctx->Protobuf(); + resp->set_group(group_); + // mterm |-> currentTerm[i], + resp->set_term(current_term_); + // mtype |-> RequestVoteResponse, + auto vote_resp = resp->mutable_vote_response(); + // mvoteGranted |-> grant, + vote_resp->set_vote_granted(grant); + AddSendableMessage(SendableRaftMessage::Reply(from, resp)); +} + +// \* Server i receives a RequestVote response from server j with +// \* m.mterm = currentTerm[i]. +void Raft::HandleVoteResponse(context::Context* ctx, const TermId& msg_term, const VoteResponse& msg, const peerid::PeerID& from) { + // HandleRequestVoteResponse(i, j, m) == + // /\ m.mterm = currentTerm[i] + if (msg_term != current_term_) { return; } + if (role_ != internal::Role::CANDIDATE) { return; } + if (msg.vote_granted()) { + if (membership().voting_replicas().count(from)) { + // /\ \/ /\ m.mvoteGranted + // /\ votesGranted' = [votesGranted EXCEPT ![i] = votesGranted[i] \cup {j}] + candidate_.votes_granted.insert(from); + MELOG(VERBOSE) << "accepted vote from " << from; + } else { + MELOG(VERBOSE) << "ignored vote from non-voting member " << from; + } + } else { + // \/ /\ ~m.mvoteGranted /\ UNCHANGED <> + MELOG(INFO) << "received vote rejected from " << from << " at " << current_term_; + } +} + +// \* Server i receives an AppendEntries request from server j with +// \* m.mterm <= currentTerm[i]. 
This just handles m.entries of length 0 or 1, but +// \* implementations could safely accept more by treating them the same as +// \* multiple independent requests of 1 entry. +void Raft::HandleAppendRequest(context::Context* ctx, const TermId& msg_term, const AppendRequest& msg, const peerid::PeerID& from) { + uint64_t prev_log_idx = msg.prev_log_idx(); + uint64_t msg_prev_log_term = msg.prev_log_term(); + uint64_t our_prev_log_term = log_->At(msg.prev_log_idx()).Term(); + // LET logOk == \/ m.mprevLogIndex = 0 + // \/ /\ m.mprevLogIndex > 0 /\ m.mprevLogIndex <= Len(log[i]) /\ m.mprevLogTerm = log[i][m.mprevLogIndex].term + bool log_ok = prev_log_idx == 0 || msg_prev_log_term == our_prev_log_term; + + // IN /\ m.mterm <= currentTerm[i] + // /\ \/ \* return to follower state + if (msg_term > current_term_) { return; } + + if (msg_term == current_term_) { + // /\ m.mterm = currentTerm[i] + switch (role_) { + case internal::Role::CANDIDATE: { + // /\ state[i] = Candidate + // /\ state' = [state EXCEPT ![i] = Follower] + SetRole(internal::Role::FOLLOWER); + follower_ = { + .leader = from, + .election = RandomElectionTimeout(), + }; + MELOG(INFO) << "dropped candidacy, became follower at " << current_term_ << " of " << from; + } break; + case internal::Role::FOLLOWER: + if (!follower_.leader.has_value()) { + MELOG(INFO) << "became follower at " << current_term_ << " of " << from; + } + follower_.leader = from; + follower_.election = RandomElectionTimeout(); + break; + case internal::Role::LEADER: + return; + } + } + // \/ /\ \* reject request + // \/ m.mterm < currentTerm[i] + // \/ /\ m.mterm = currentTerm[i] + // /\ state[i] = Follower + // /\ \lnot logOk + if (msg_term < current_term_ || ( + msg_term == current_term_ && + role_ == internal::Role::FOLLOWER && + !log_ok)) { + LogIdx our_last_idx = log_->last_idx(); + if (msg_term < current_term_) { + LOG(INFO) << "ignored message with " << msg_term << " < current " << current_term_; + } else if (our_prev_log_term > 
0) { + LOG(WARNING) << "rejected append from " << from << " with " << prev_log_idx << " at " << msg_prev_log_term << ", we have " << our_prev_log_term; + } else { + LOG(INFO) << "rejected append from " << from << " with " << prev_log_idx << ", we are behind at " << our_last_idx; + } + + // /\ Reply([ + auto out = ctx->Protobuf(); + out->set_group(group_); + // mterm |-> currentTerm[i], + out->set_term(current_term_); + // mtype |-> AppendEntriesResponse, + auto append = out->mutable_append_response(); + // msuccess |-> FALSE, + append->set_success(false); + // mmatchIndex |-> 0, + // We send our commit index as the last index we know we matched. If we committed + // up to a point in time, we know we match with the rest of the Raft group up to + // that index, so this should be safe. + append->set_match_idx(commit_idx_); + append->set_last_log_idx(our_last_idx); + append->set_promise_idx(promise_idx_); + AddSendableMessage(SendableRaftMessage::Reply(from, out)); + return; + } + // \/ \* accept request + // /\ m.mterm = currentTerm[i] + // /\ state[i] = Follower + // /\ logOk + // ... 
and the TLA+ that follows doesn't correspond to procedural code well + // find point of log conflict + CHECK(msg_term == current_term_); + CHECK(role_ == internal::Role::FOLLOWER); + CHECK(log_ok); + uint64_t last_processed_idx = prev_log_idx; + for (int i = 0; i < msg.entries_size(); i++) { + uint64_t msg_entry_log_idx = prev_log_idx + i + 1; + const LogEntry& msg_entry = msg.entries(i); + TermId our_idx_term = log_->At(msg_entry_log_idx).Term(); + if (our_idx_term != 0 && our_idx_term != msg_entry.term()) { + if (msg_entry_log_idx <= commit_idx_) { + LOG(WARNING) << "mismatch prior to commit: " << msg_entry_log_idx << " <= " << commit_idx_; + break; + } else if (msg_entry_log_idx <= promise_idx_) { + LOG(WARNING) << "mismatch prior to promise: " << msg_entry_log_idx << " <= " << promise_idx_; + break; + } else if (error::OK != log_->CancelFrom(msg_entry_log_idx)) { + LOG(WARNING) << "failed to cancel logs from " << msg_entry_log_idx; + break; + } + // CancelFrom(msg_entry_log_idx) has the potential to chop off an + // uncommitted membership from the end of the log; check that: + MaybeChangeUncommittedMembershipsBasedOnLog(); + // If this succeeds, the next if statement should always be true. 
+ } + LogIdx last = log_->last_idx(); + if (msg_entry_log_idx == last + 1) { + auto next_hash = NextHash(msg_entry); + if (!util::ConstantTimeEquals(next_hash, msg_entry.hash_chain())) { + LOG(WARNING) << "failed to append log: hash chain mismatch at " << msg_entry_log_idx; + break; + } + auto [loc, err] = LogAppend(msg_entry); + if (err != error::OK) { + LOG(WARNING) << "failed to append log " << msg_entry_log_idx; + break; + } else { + LOG(VERBOSE) << "appended log index " << msg_entry_log_idx; + } + } + last_processed_idx = msg_entry_log_idx; + } + + LogIdx leader_commit = std::min(msg.leader_commit(), last_processed_idx); + LogIdx leader_promise = std::min(msg.leader_promise(), last_processed_idx); + LOG(DEBUG) << "commit=" << leader_commit << " lcommit=" << msg.leader_commit() + << " promise=" << leader_promise << " lpromise=" << msg.leader_promise() + << " last=" << last_processed_idx; + // TLA+... and we're back! + // /\ commitIndex' = [commitIndex EXCEPT ![i] = m.mcommitIndex] + if (leader_commit > commit_idx_) { + LOG(VERBOSE) << "committed transactions from " << commit_idx_ << " to " << leader_commit; + COUNTER(raft, logs_committed)->IncrementBy(leader_commit - commit_idx_); + commit_idx_ = leader_commit; + GAUGE(raft, commit_index)->Set(commit_idx_); + // Updating the commit index has the potential to commit an uncommitted + // membership; check that: + MaybeChangeUncommittedMembershipsBasedOnLog(); + } + if (leader_promise > promise_idx_) { + LOG(VERBOSE) << "promised transactions from " << promise_idx_ << " to " << leader_promise; + COUNTER(raft, logs_promised)->IncrementBy(leader_promise - promise_idx_); + promise_idx_ = leader_promise; + GAUGE(raft, promise_index)->Set(promise_idx_); + } + + auto out = ctx->Protobuf(); + // /\ Reply([ + out->set_group(group_); + // mterm |-> currentTerm[i], + out->set_term(current_term_); + // mtype |-> AppendEntriesResponse, + auto append = out->mutable_append_response(); + // msuccess |-> TRUE, + 
append->set_success(true); + // mmatchIndex |-> m.mprevLogIndex + Len(m.mentries), + append->set_match_idx(last_processed_idx); + append->set_promise_idx(promise_idx_); + append->set_last_log_idx(log_->last_idx()); + AddSendableMessage(SendableRaftMessage::Reply(from, out)); +} + +void Raft::MaybeChangeUncommittedMembershipsBasedOnLog() { + bool changed = false; + // We may have committed some of the previously uncommitted membership + // changes by moving the commit index forward; pop them off the front. + while (uncommitted_memberships_.size() > 0 + && uncommitted_memberships_.front().first <= commit_idx_) { + auto f = uncommitted_memberships_.begin(); + LOG(VERBOSE) << "promoting committed membership at " << f->first; + membership_ = std::move(f->second); + uncommitted_memberships_.pop_front(); + changed = true; + } + // We may have rolled back the log via CancelFrom, chopping some + // uncommitted memberships off the back. Remove them. + while (uncommitted_memberships_.size() > 0 + && uncommitted_memberships_.back().first > log_->last_idx()) { + LOG(VERBOSE) << "discarding Uncommitted membership at " << uncommitted_memberships_.back().first; + uncommitted_memberships_.pop_back(); + changed = true; + } + // If we've changed our uncommitted memberships in any way + // that may have affected the active membership we should use, + // handle those changes. + if (changed) { HandleMembershipChange(); } +} + +// \* Server i receives an AppendEntries response from server j with +// \* m.mterm = currentTerm[i]. 
+void Raft::HandleAppendResponse(context::Context* ctx, const TermId& msg_term, const AppendResponse& msg, const peerid::PeerID& from) {
+  // HandleAppendEntriesResponse(i, j, m) ==
+  // /\ m.mterm = currentTerm[i]
+  if (msg_term != current_term_) { return; }  // response is for a different term: ignore
+  if (role_ != internal::Role::LEADER) { return; }  // only the leader keeps per-follower replication state
+  if (leader_.followers.count(from) == 0) { return; }  // no replication state for this sender
+  internal::ReplicationState& replication = leader_.followers[from];
+  if (msg.success()) {
+    // /\ \/ /\ m.msuccess \* successful
+    if (replication.inflight.has_value() && msg.match_idx() >= (*replication.inflight)) {
+      replication.inflight.reset();  // follower has matched at or past the in-flight index
+    }
+    if (msg.match_idx() + 1 > replication.next_idx) {
+      // /\ nextIndex' = [nextIndex EXCEPT ![i][j] = m.mmatchIndex + 1]
+      replication.next_idx = msg.match_idx() + 1;
+    }
+    if (msg.match_idx() > replication.match_idx) {
+      // /\ matchIndex' = [matchIndex EXCEPT ![i][j] = m.mmatchIndex]
+      replication.match_idx = msg.match_idx();
+    }
+    if (msg.promise_idx() > replication.promise_idx) {
+      replication.promise_idx = msg.promise_idx();
+    }
+    replication.send_probe = false;
+    return;
+  }
+  // \/ /\ \lnot m.msuccess \* not successful
+  if (replication.send_probe) {
+    LOG(VERBOSE) << "received probe append rejection at " << replication.next_idx << " from " << from << " having " << msg.last_log_idx();
+  } else {
+    LOG(INFO) << "received append rejection at " << replication.next_idx << " from " << from << " having " << msg.last_log_idx();
+  }
+  // /\ nextIndex' = [nextIndex EXCEPT ![i][j] = Max({nextIndex[i][j] - 1, 1})]
+  replication.next_idx = std::max(  // clamp: at least match_idx+1, at most one step back or the follower's last_log_idx+1
+      msg.match_idx() + 1,
+      std::min(
+          replication.next_idx - 1,
+          msg.last_log_idx() + 1));
+  replication.send_probe = true;
+  replication.inflight.reset();
+  uint64_t chunk_size_remaining = config_.replication_chunk_bytes();  // byte budget for walking next_idx further back
+  const uint64_t overflow = (uint64_t(0)) - 1;  // the value next_idx wraps to when decremented past 0; ends the loop
+  for (uint64_t next_idx = replication.next_idx - 1; next_idx != overflow; next_idx--) {
+    if (next_idx <= msg.match_idx()) { break; }
+    size_t log_entry_size = log_->At(replication.next_idx).SerializedSize();  // NOTE(review): sizes the entry at replication.next_idx (candidate next_idx + 1), not At(next_idx) - confirm intended
+    if (log_entry_size > chunk_size_remaining) { break; }
+    chunk_size_remaining -= log_entry_size;
+    replication.next_idx = next_idx;
+  }
+}
+
+void Raft::AddSendableMessage(SendableRaftMessage msg) {
+  if (msg.to().has_value()) {
+    // Make sure we're not looping messages
+    CHECK(*msg.to() != me_);
+  } else {
+    // Don't bother adding a broadcast message if we're the only one in the group.
+    if (membership().all_replicas().size() == 1 && membership().all_replicas().count(me_)) {
+      return;
+    }
+  }
+  sendable_messages_.push_back(msg);  // queued in sendable_messages_ for later retrieval
+}
+
+// \* Any RPC with a newer term causes the recipient to advance its term first.
+void Raft::UpdateTerm(const peerid::PeerID& from, const RaftMessage& msg) {
+  // UpdateTerm(i, j, m) ==
+  // /\ m.mterm > currentTerm[i]
+  if (msg.term() <= current_term_) { return; }
+  LOG(INFO) << "becoming follower at " << msg.term() << " (from " << current_term_ << ") due to message from " << from;
+  // /\ currentTerm' = [currentTerm EXCEPT ![i] = m.mterm]
+  COUNTER(raft, term_updated)->Increment();
+  COUNTER(raft, term_increments)->IncrementBy(msg.term() - current_term_);
+  current_term_ = msg.term();
+  GAUGE(raft, current_term)->Set(current_term_);
+  // /\ state' = [state EXCEPT ![i] = Follower]
+  util::Ticks new_election_ticks;  // election ticks to carry into the new follower state
+  switch (role_) {
+    case internal::Role::FOLLOWER:
+      new_election_ticks = follower_.election;
+      break;
+    case internal::Role::CANDIDATE:
+      new_election_ticks = candidate_.election;
+      break;
+    default: // LEADER: nothing to carry over here, so draw a fresh random timeout
+      new_election_ticks = RandomElectionTimeout();
+  }
+  SetRole(internal::Role::FOLLOWER);
+  follower_ = {
+    .election = new_election_ticks,
+  };
+  // /\ votedFor' = [votedFor EXCEPT ![i] = Nil]
+  voted_for_.reset();
+}
+
+// \* Responses with stale terms are ignored.
+bool Raft::ShouldDropResponseDueToStaleTerm(const peerid::PeerID& from, const RaftMessage& msg) {
+  // DropStaleResponse(i, j, m) ==
+  // /\ m.mterm < currentTerm[i]
+  if (msg.term() < current_term_) {
+    // /\ Discard(m)
+    LOG(INFO) << "ignoring message with " << msg.term() << " < current " << current_term_ << " from " << from;
+    return true;
+  }
+  return false;
+}
+
+// \* Receive a message.
+void Raft::Receive(context::Context* ctx, const RaftMessage& msg, const peerid::PeerID& from) {
+  // Receive(m) ==
+  if (msg.group() != group_) {
+    LOG(ERROR) << "received raft message from " << from << " for wrong group " << msg.group();
+    return;
+  }
+  if (membership().all_replicas().count(from) == 0) {
+    LOG(ERROR) << "message from non-peer " << from;
+    return;
+  }
+  // IN \* Any RPC with a newer term causes the recipient to advance
+  // \* its term first. Responses with stale terms are ignored.
+  // \/ UpdateTerm(i, j, m)
+  UpdateTerm(from, msg);
+  if (role_ == internal::Role::LEADER) {
+    auto f = leader_.followers.find(from);
+    if (f != leader_.followers.end()) {
+      f->second.last_seen_ticks = 0;  // we just heard from this follower
+    }
+  }
+  switch (msg.inner_case()) {
+    case RaftMessage::kVoteRequest:
+      // \/ /\ m.mtype = RequestVoteRequest
+      // /\ HandleRequestVoteRequest(i, j, m)
+      COUNTER(raft, vote_requests_received)->Increment();
+      LOG(VERBOSE) << "HandleVoteRequest";
+      HandleVoteRequest(ctx, msg.term(), msg.vote_request(), from);
+      break;
+    case RaftMessage::kVoteResponse:
+      // \/ /\ m.mtype = RequestVoteResponse
+      // /\ \/ DropStaleResponse(i, j, m)
+      COUNTER(raft, vote_responses_received)->Increment();
+      if (ShouldDropResponseDueToStaleTerm(from, msg)) { break; }
+      // \/ HandleRequestVoteResponse(i, j, m)
+      LOG(VERBOSE) << "HandleVoteResponse";
+      HandleVoteResponse(ctx, msg.term(), msg.vote_response(), from);
+      break;
+    case RaftMessage::kAppendRequest:
+      // \/ /\ m.mtype = AppendEntriesRequest
+      // /\ HandleAppendEntriesRequest(i, j, m)
+      COUNTER(raft, append_requests_received)->Increment();
+      LOG(VERBOSE) << "HandleAppendRequest";
+      HandleAppendRequest(ctx, msg.term(), msg.append_request(), from);
+      break;
+    case RaftMessage::kAppendResponse:
+      // \/ /\ m.mtype = AppendEntriesResponse
+      // /\ \/ DropStaleResponse(i, j, m)
+      COUNTER(raft, append_responses_received)->Increment();
+      if (ShouldDropResponseDueToStaleTerm(from, msg)) { break; }
+      // \/ HandleAppendEntriesResponse(i, j, m)
+      LOG(VERBOSE) << "HandleAppendResponse";
+      HandleAppendResponse(ctx, msg.term(), msg.append_response(), from);
+      break;
+    case RaftMessage::kTimeoutNow:
+      COUNTER(raft, timeout_nows_received)->Increment();
+      if (ShouldDropResponseDueToStaleTerm(from, msg)) { break; }
+      LOG(VERBOSE) << "TimeoutNow";
+      ElectionTimeout(ctx);  // run the election-timeout path right away instead of waiting for ticks
+      break;
+    case RaftMessage::INNER_NOT_SET:
+      COUNTER(raft, invalid_requests_received)->Increment();
+      LOG(ERROR) << "unhandled message case from " << from;
+      break;
+  }
+  MaybeChangeStateAndSendMessages(ctx);
+}
+
+void Raft::MaybeChangeStateAndSendMessages(context::Context* ctx) {
+  MaybeBecomeLeader(ctx);
+  MaybeAdvanceCommitIndex();
+  for (auto peer : membership().all_replicas()) {
+    if (peer == me_) { continue; }
+    AppendEntries(ctx, peer);
+  }
+  if (role_ == internal::Role::LEADER && leader_.relinquishing) {
+    TryToRelinquishLeadership(ctx);
+  }
+}
+
+void Raft::TryToRelinquishLeadership(context::Context* ctx) {
+  bool relinquishing = false;
+  peerid::PeerID next;
+  for (auto peer : membership().voting_replicas()) {
+    if (peer == me_) { continue; }
+    auto iter = leader_.followers.find(peer);
+    if (iter == leader_.followers.end()) { continue; }
+    const internal::ReplicationState& state = iter->second;
+    if (state.match_idx == log_->last_idx()) {  // successor must have matched our entire log
+      next = peer;
+      relinquishing = true;
+      break;
+    }
+  }
+  if (relinquishing) {
+    // Finally, a worthy successor. Request that it immediately execute an election timeout to
+    // become the new leader.
+    LOG(INFO) << "Relinquishing leadership to " << next;
+    auto msg = ctx->Protobuf();
+    msg->set_group(group_);
+    msg->set_term(current_term_);
+    msg->set_timeout_now(true);
+    AddSendableMessage(SendableRaftMessage::Reply(next, msg));
+    // We've relinquished leadership; become a follower at our current term.
+    SetRole(internal::Role::FOLLOWER);
+    follower_ = {
+      .election = RandomElectionTimeout(),
+    };
+  }
+}
+
+util::Ticks Raft::RandomElectionTimeout() const {
+  return config_.election_ticks() + rand() % config_.election_ticks();  // in [election_ticks, 2*election_ticks); NOTE(review): modulo is undefined if election_ticks() == 0 - presumably rejected by config validation, confirm
+}
+
+std::string MsgStr(const RaftMessage& msg) {
+  std::stringstream ss;
+  ss << "group:" << msg.group() << " term:" << msg.term();
+  switch (msg.inner_case()) {
+    case RaftMessage::kVoteRequest: {
+      auto m = msg.vote_request();
+      ss << " vote_request:{ last_log_idx:" << m.last_log_idx() << " last_log_term:" << m.last_log_term() << " }";
+    } break;
+    case RaftMessage::kVoteResponse: {
+      auto m = msg.vote_response();
+      ss << " vote_response:{ vote_granted:" << m.vote_granted() << " }";
+    } break;
+    case RaftMessage::kAppendRequest: {
+      auto m = msg.append_request();
+      ss << " append_request:{ prev_log_idx:" << m.prev_log_idx()
+         << " prev_log_term:" << m.prev_log_term()
+         << " leader_commit:" << m.leader_commit()
+         << " leader_promise:" << m.leader_promise();
+      for (int i = 0; i < m.entries_size(); i++) {
+        ss << " entries:{ term=" << m.entries(i).term() << " }";
+      }
+      ss << " }";
+    } break;
+    case RaftMessage::kAppendResponse: {
+      auto m = msg.append_response();
+      ss << " append_response:{ success:" << m.success()
+         << " match_idx:" << m.match_idx()
+         << " promise_idx:" << m.promise_idx()
+         << " last_log_idx:" << m.last_log_idx() << " }";
+    } break;
+    case RaftMessage::kTimeoutNow: {
+      ss << " timeout_now:" << msg.timeout_now();
+    } break;
+    case RaftMessage::INNER_NOT_SET:
+      ss << "INNER_NOT_SET";
+      break;
+  }
+  return ss.str();
+}
+
+const Membership& Raft::membership() const {
+  // If we have any uncommitted memberships, use the most recent.
+  if (uncommitted_memberships_.size()) {
+    return *uncommitted_memberships_.back().second;
+  }
+  // If not, use the canonical last-committed one.
+  return *membership_;
+}
+
+util::Ticks Raft::last_seen_ticks(const peerid::PeerID& follower) const {
+  if (role_ != internal::Role::LEADER) { return util::InvalidTicks; }
+  auto f = leader_.followers.find(follower);
+  if (f == leader_.followers.end()) { return util::InvalidTicks; }
+  return f->second.last_seen_ticks;
+}
+
+// Precondition: this Raft is the leader (CHECKed). `follower` should be the id of a peer in the
+// list of followers and not this core's id; if it is not found, OK is returned and `status` is untouched.
+error::Error Raft::FollowerReplicationStatus(const peerid::PeerID& follower, EnclavePeerReplicationStatus* status) const {
+  CHECK(role_ == internal::Role::LEADER);
+  auto f = leader_.followers.find(follower);
+  if(f == leader_.followers.end()) {
+    return error::OK;
+  }
+  status->set_next_index(f->second.next_idx);
+  status->set_match_index(f->second.match_idx);
+  if(f->second.inflight.has_value()) {
+    status->set_inflight_index(f->second.inflight.value());
+  }
+  status->set_probing(f->second.send_probe);
+  return error::OK;
+}
+
+} // namespace svr2::raft
diff --git a/enclave/raft/raft.h b/enclave/raft/raft.h
new file mode 100644
index 0000000..37c9759
--- /dev/null
+++ b/enclave/raft/raft.h
@@ -0,0 +1,289 @@
+// Copyright 2023 Signal Messenger, LLC
+// SPDX-License-Identifier: AGPL-3.0-only
+
+#ifndef __SVR2_RAFT_RAFT_H__
+#define __SVR2_RAFT_RAFT_H__
+
+#include
+#include
+#include
+#include
+#include "peerid/peerid.h"
+#include "proto/error.pb.h"
+#include "proto/raft.pb.h"
+#include "proto/msgs.pb.h"
+#include "util/ticks.h"
+#include "raft/types.h"
+#include "raft/log.h"
+#include "raft/internal.h"
+#include "raft/membership.h"
+#include "context/context.h"
+
+namespace svr2::raft {
+
+// MsgStr returns a debug string of the contents of [msg].
+std::string MsgStr(const RaftMessage& msg); + +// SendableRaftMessage wraps a message that should be sent out to one +// or a set of other Raft instances. Messages are either broadcast, +// which should be sent to all `peers()` of the Raft instance, or +// targetted, which should be sent to a single instance. If targetted, +// `to().has_value()` will be true. +class SendableRaftMessage { + public: + static SendableRaftMessage Broadcast(RaftMessage* msg) { + return SendableRaftMessage(msg, std::optional()); + } + static SendableRaftMessage Reply(const peerid::PeerID& to, RaftMessage* msg) { + return SendableRaftMessage(msg, to); + } + const RaftMessage& message() { return *message_; } + // If `!to().has_value()`, this is a broadcast message and should be + // sent to all `raft.peers()`. + const std::optional& to() { return to_; } + private: + SendableRaftMessage(RaftMessage* msg, std::optional t) + : message_(msg), to_(t) {} + RaftMessage* message_; + // In the original rust Raft code, this was `from`, in a `Reply` enum. + std::optional to_; +}; + +// Raft provides an implementation of the Raft protocol. +// This implementation is not safe for concurrent access. +// +// These are the major functions used to do Raft-y things: +// +// Requesting and receiving actual log entries: +// ClientRequest - request that an entry be added to the log +// TakeCommittedLog - return the next log entry that's been committed +// Internal Raft stuff: +// Receive - receive a message from another Raft +// SendableMessages - get any messages to send to other Raft +// TimerTick - the inevitable march of time +// +// A creator of this class MUST: +// +// - regularly call TimerTick +// - call Receive whenever another Raft sends it a message +// - regularly call TakeCommittedLog and process the output +// - call SendableMessages after each call that takes a context::Context* and +// consume the results before that context falls out of scope. 
+// +// A creator of this class MAY: +// +// - call ClientRequest to request that entries be appended to the canonical log +class Raft { + public: + DELETE_COPY_AND_ASSIGN(Raft); + Raft( + GroupId group, + const peerid::PeerID& me, + std::unique_ptr membership, + std::unique_ptr log, + const enclaveconfig::RaftConfig& config, + bool committed_log, + size_t super_majority); + + // Simple getters + + // group_id contains a unique group identifier that allows the Raft instance + // to make sure it's not accidentally talking to a different set of raft + // servers than it thinks it is. That way madness lies. + const GroupId& group_id() const { return group_; } + // last_applied returns the log index of the last log that has been requested, + // but which may or may not be committed. + const LogIdx& last_applied() const { return last_applied_; } + // commit_idx returns the log index of the last committed log. This should + // monotonically increase. + const LogIdx& commit_idx() const { return commit_idx_; } + // log returns a const reference to this Raft's underlying Log. Note that Log + // is not safe for concurrent access, so should not be accessed concurrently + // with function calls on this Raft. + const Log& log() const { return *log_; } + // all_replicas returns the set of peer IDs for all members of this Raft's group. + // It may contain the ID for this Raft; peers() returns the same set with that ID erased. + const std::set& all_replicas() const { return membership().all_replicas(); } + const std::set peers() const { + std::set out = membership().all_replicas(); + out.erase(me_); + return out; + } + const Membership& committed_membership() const { return *membership_; } + const peerid::PeerID& me() const { return me_; } + // is_leader returns true when this Raft thinks it is the leader of the + // Raft group. + bool is_leader() const { return role_ == internal::Role::LEADER; } + // leader returns the suspected current leader of this Raft group.
+ std::optional leader() const; + // current_term returns the current Raft term. + const TermId& current_term() const { return current_term_; } + // quorum_size returns the size of the smallest majority among this + // Raft and its voting peers. + static size_t quorum_size(size_t voting_replicas, size_t super_majority); + size_t membership_quorum_size() const; + // voting() returns true if we believe we are a voting member of the current + // replica group. + bool voting() const { return membership().voting_replicas().count(me_); } + // If this is the leader, return the number of ticks ago when we saw a + // message from the given follower. If not leader or follower not found, + // returns InvalidTicks. + util::Ticks last_seen_ticks(const peerid::PeerID& follower) const; + const enclaveconfig::RaftConfig& config() const { return config_; } + error::Error FollowerReplicationStatus(const peerid::PeerID& follower, EnclavePeerReplicationStatus* status) const; + + // Simple setters + void set_replication_chunk_size(size_t s) { config_.set_replication_chunk_bytes(s); } + + // More complicated functions. For each function that takes a context::Context, + // SendableMessages() must be called after that function completes and before + // that context falls out of scope. + + // Request that a log entry containing the given data be added to the Raft log. + // Requires that `is_leader()` is true. + // + // If successful, this log returns the location where the log _may_ be + // committed. You can tell if the log was successfully added if TakeCommittedLog + // returns a log entry with a matching location (term+idx). + std::pair ClientRequest(context::Context* ctx, const std::string& data); + // Request that a new replica group configuration be adopted by the Raft + // group. Requires that `is_leader()` is true, and that the configuration + // is an acceptable next configuration from the current one. 
+ std::pair ReplicaGroupChange(context::Context* ctx, const ReplicaGroup& g); + // Receive a Raft message from another replica. + // Send messages from SendableMessages after this call. + void Receive(context::Context* ctx, const RaftMessage& msg, const peerid::PeerID& from); + // Tick the timer. This code currently treats each call to this function + // as a single tick. Note that this does not currently correlate at all + // with any real-time measure (it's not a second, per se). + // Send messages from SendableMessages after this call. + void TimerTick(context::Context* ctx); + // ResetPeer lets this Raft instance know that the given peer ID + // may have lost some of the messages we sent to it previously. + void ResetPeer(context::Context* ctx, const peerid::PeerID& id); + // Reconfigure sets the RaftConfig to a new value. + void Reconfigure(const enclaveconfig::RaftConfig& config); + // If I'm the leader, attempt to pawn that responsibility off on someone else. + void RelinquishLeadership(context::Context* ctx); + + // Return the list of messages that should be sent to other peers. + std::vector SendableMessages() { return std::move(sendable_messages_); } + + // Pop the next committed log entry off the list, if there is one. + // On success, LogIdx will be nonzero and LogEntry will be filled in + // If there is no new committed log, LogIdx will be zero and LogEntry + // will be empty. + std::pair TakeCommittedLog(); + + const Membership& membership() const; + + private: + void set_heartbeat_timeout(util::Ticks t); + void set_election_timeout(util::Ticks t); + + // MaybeBecomeLeader sometimes wants to append a log entry. This call + // allows it to do so without recursing MaybeChangeStateAndSendMessages. + std::pair ClientRequestInternal(LogEntry* entry); + // Set role and clear all current role state. + void SetRole(internal::Role r); + // Get the current leader as understood by this Raft, if there is one. 
+ std::optional Leader() const; + // Append the given entry to the log. + std::pair LogAppend(const LogEntry& entry); + // Called by TimerTick() when an election timeout occurs to start a new election. + void ElectionTimeout(context::Context* ctx); + // Returns a new #ticks to wait before the next election, randomly (weak) in range + // [config_.election_ticks(), config_.election_ticks()*2) + util::Ticks RandomElectionTimeout() const; + // Any RPC with a newer term causes the recipient to advance its term first. + void UpdateTerm(const peerid::PeerID& peer, const RaftMessage& msg); + // Returns true (and logs) if the given message should be dropped due to its term + // being stale. + bool ShouldDropResponseDueToStaleTerm(const peerid::PeerID& from, const RaftMessage& msg); + // Check for any state changes that may require us to send more messages. + void MaybeChangeStateAndSendMessages(context::Context* ctx); + // Check if, as a candidate, we have enough to become the leader. + // Part of MaybeChangeStateAndSendMessages. + void MaybeBecomeLeader(context::Context* ctx); + // Check if, given the information we have, we can advance the commit index. + // Part of MaybeChangeStateAndSendMessages. + void MaybeAdvanceCommitIndex(); + // Try to find a worthy replica to take over as leader. If one is found, + // send it a timeout_now to become the new leader. + void TryToRelinquishLeadership(context::Context* ctx); + // Set uncommitted membership on leader. + void AddUncommittedMembership(LogIdx idx, std::unique_ptr membership); + void HandleMembershipChange(); + // See if uncommitted membership is now committed, and if so make it the + // canonical one. + void MaybeChangeUncommittedMembershipsBasedOnLog(); + // on uncommitted logs. + // If leader, send message to peer_id requesting that they append entries to + // their log. + // Part of MaybeChangeStateAndSendMessages.
+ void AppendEntries(context::Context* ctx, const peerid::PeerID& peer); + // Get the next hash for the next log entry. + std::array NextHash(const LogEntry& next_entry); + + // Request handlers + void HandleVoteRequest(context::Context* ctx, const TermId& msg_term, const VoteRequest& msg, const peerid::PeerID& from); + void HandleVoteResponse(context::Context* ctx, const TermId& msg_term, const VoteResponse& msg, const peerid::PeerID& from); + void HandleAppendRequest(context::Context* ctx, const TermId& msg_term, const AppendRequest& msg, const peerid::PeerID& from); + void HandleAppendResponse(context::Context* ctx, const TermId& msg_term, const AppendResponse& msg, const peerid::PeerID& from); + + void AddSendableMessage(SendableRaftMessage msg); + + // Message to request a vote for myself. Requires role==candidate. + RaftMessage* RequestVoteMessage(context::Context* ctx); + + GroupId group_; + peerid::PeerID me_; + std::unique_ptr membership_; + // uncommitted_memberships_ keeps an ordered list of the uncommitted-but- + // active memberships based on the log. We're effectively certain that + // once a full request (AppendEntries, etc) is complete, this should have + // exactly zero or one element in it, and thus can probably be not-a-list. + std::list>> uncommitted_memberships_; + + enclaveconfig::RaftConfig config_; + + LogIdx last_applied_; + + // \* The server's term number. + // VARIABLE currentTerm + TermId current_term_; + + // \* The candidate the server voted for in its current term, or + // \* Nil if it hasn't voted for any. + // VARIABLE votedFor + std::optional voted_for_; + + // \* The server's state (Follower, Candidate, or Leader). + // VARIABLE state + internal::Role role_; + internal::FollowerState follower_; + internal::CandidateState candidate_; + internal::LeaderState leader_; + + // \* A Sequence of log entries. The index into this sequence is the index of the + // \* log entry. 
Unfortunately, the Sequence module defines Head(s) as the entry + // \* with index 1, so be careful not to use that! + // VARIABLE log + std::unique_ptr log_; + + // \* The index of the latest entry in the log the state machine may apply. + // VARIABLE commitIndex + LogIdx commit_idx_; + // We promise to commit at the given index; we will not truncate our log past + // this point. + LogIdx promise_idx_; + + // The list of messages that are generated to send out based on various actions. + std::vector sendable_messages_; + + size_t super_majority_; +}; + +} // namespace svr2::raft + +#endif // __SVR2_RAFT_RAFT_H__ diff --git a/enclave/raft/tests/log.cc b/enclave/raft/tests/log.cc new file mode 100644 index 0000000..cc29e6f --- /dev/null +++ b/enclave/raft/tests/log.cc @@ -0,0 +1,109 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP peerid +//TESTDEP context +//TESTDEP sip +//TESTDEP sender +//TESTDEP env +//TESTDEP env/test +//TESTDEP env +//TESTDEP util +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include +#include "raft/log.h" +#include "env/env.h" + +namespace svr2::raft { + +class LogTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } +}; + +TEST_F(LogTest, BasicUsage) { + Log log(1<<20); + EXPECT_EQ(0, log.oldest_stored_idx()); + EXPECT_EQ(0, log.last_idx()); + EXPECT_EQ(1, log.next_idx()); + EXPECT_EQ(0, log.last_term()); + + LogEntry e; + e.set_term(1); + e.set_hash_chain("12345678901234567890123456789012"); + ASSERT_EQ(error::OK, log.Append(e, 1)); + EXPECT_EQ(1, log.oldest_stored_idx()); + EXPECT_EQ(1, log.last_idx()); + EXPECT_EQ(2, log.next_idx()); + EXPECT_EQ(1, log.last_term()); + + e.set_term(2); + e.set_hash_chain("12345678901234567890123456789012"); + ASSERT_EQ(error::OK, log.Append(e, 1)); + EXPECT_EQ(1, log.oldest_stored_idx()); + EXPECT_EQ(2, log.last_idx()); + EXPECT_EQ(3, 
log.next_idx()); + EXPECT_EQ(2, log.last_term()); + + auto i1 = log.At(4); + EXPECT_FALSE(i1.Valid()); + auto i2 = log.At(0); + EXPECT_FALSE(i2.Valid()); + auto i3 = log.At(1); + EXPECT_TRUE(i3.Valid()); + EXPECT_EQ(1, i3.Index()); + EXPECT_EQ(1, i3.Term()); + EXPECT_EQ(36, i3.SerializedSize()); + i3.Next(); + EXPECT_TRUE(i3.Valid()); + EXPECT_EQ(2, i3.Index()); + EXPECT_EQ(2, i3.Term()); + EXPECT_EQ(36, i3.SerializedSize()); + i3.Next(); + EXPECT_FALSE(i3.Valid()); + EXPECT_EQ(0, i3.Index()); + EXPECT_EQ(0, i3.Term()); + EXPECT_EQ(0, i3.SerializedSize()); + + EXPECT_EQ(1, log.oldest_stored_idx()); + EXPECT_EQ(2, log.last_idx()); + EXPECT_EQ(3, log.next_idx()); + EXPECT_EQ(2, log.last_term()); + auto i4 = log.At(2); + EXPECT_TRUE(i4.Valid()); + EXPECT_EQ(2, i4.Index()); + EXPECT_EQ(2, i4.Term()); + EXPECT_EQ(36, i4.SerializedSize()); + i4.Next(); + EXPECT_FALSE(i4.Valid()); + EXPECT_EQ(0, i4.Index()); + EXPECT_EQ(0, i4.Term()); + EXPECT_EQ(0, i4.SerializedSize()); +} + +TEST_F(LogTest, RunningOutOfSpace) { + LogEntry e; + e.set_data("abc"); + e.set_hash_chain("12345678901234567890123456789012"); + e.set_term(1); + size_t s = Log::logentry_bytes_in_log(e); + ASSERT_EQ(s, 147); + Log log(s*3+1); + ASSERT_EQ(error::OK, log.Append(e, 1)); + ASSERT_EQ(error::OK, log.Append(e, 1)); + ASSERT_EQ(error::OK, log.Append(e, 1)); + ASSERT_EQ(error::Raft_LogOutOfSpace, log.Append(e, 1)); + ASSERT_EQ(3, log.last_idx()); + ASSERT_EQ(error::OK, log.Append(e, 2)); + ASSERT_EQ(error::Raft_LogOutOfSpace, log.Append(e, 2)); +} + +} // namespace svr2::raft diff --git a/enclave/raft/tests/membership.cc b/enclave/raft/tests/membership.cc new file mode 100644 index 0000000..2a5543c --- /dev/null +++ b/enclave/raft/tests/membership.cc @@ -0,0 +1,132 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP peerid +//TESTDEP context +//TESTDEP sip +//TESTDEP sender +//TESTDEP env +//TESTDEP env/test +//TESTDEP env +//TESTDEP util 
+//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include +#include "peerid/peerid.h" +#include "raft/membership.h" +#include "env/env.h" +#include "util/log.h" + +namespace svr2::raft { + +class MembershipTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } + + error::Error ValidProgression(const ReplicaGroup& g1, const ReplicaGroup& g2, const std::string& leader, size_t supermajority) { + auto [m1, err1] = Membership::FromProto(g1); + auto [m2, err2] = Membership::FromProto(g2); + CHECK(err1 == error::OK && err2 == error::OK); + peerid::PeerID leader_peer; + CHECK(error::OK == leader_peer.FromString(leader)); + auto err = Membership::ValidProgressionForLeader(leader_peer, *m1, *m2, supermajority); + LOG(INFO) << "ValidProgressionForLeader: " << err; + return err; + } +}; + +TEST_F(MembershipTest, FromProtoBadPeer) { + ReplicaGroup g; + g.add_replicas()->set_peer_id("invalid"); + auto [out, err] = Membership::FromProto(g); + EXPECT_EQ(out.get(), nullptr); + EXPECT_EQ(err, error::Peers_InvalidID); +} + +TEST_F(MembershipTest, FromProtoDuplicatePeer) { + ReplicaGroup g; + g.add_replicas()->set_peer_id("12345678901234567890123456789012"); + g.add_replicas()->set_peer_id("12345678901234567890123456789012"); + auto [out, err] = Membership::FromProto(g); + EXPECT_EQ(out.get(), nullptr); + EXPECT_EQ(err, error::Membership_DuplicateReplicaInReplicaGroup); +} + +TEST_F(MembershipTest, FromProtoSuccess) { + ReplicaGroup g; + g.add_replicas()->set_peer_id("REPLICA........................0"); + g.add_replicas()->set_peer_id("REPLICA........................1"); + g.add_replicas()->set_peer_id("REPLICA........................2"); + g.add_replicas()->set_peer_id("REPLICA........................3"); + g.mutable_replicas(1)->set_voting(true); + g.mutable_replicas(2)->set_voting(true); + auto [out, err] = Membership::FromProto(g); + EXPECT_NE(out.get(), nullptr); + EXPECT_EQ(err, error::OK); + 
EXPECT_EQ(4, out->all_replicas().size()); + EXPECT_EQ(2, out->voting_replicas().size()); + EXPECT_EQ(1, out->voting_replicas().count(peerid::PeerID(reinterpret_cast("REPLICA........................1")))); + EXPECT_EQ(1, out->voting_replicas().count(peerid::PeerID(reinterpret_cast("REPLICA........................2")))); +} + +TEST_F(MembershipTest, ValidProgressionForLeader) { + ReplicaGroup g1; + g1.add_replicas()->set_peer_id("12345678901234567890123456789012"); + g1.add_replicas()->set_peer_id("22345678901234567890123456789012"); + g1.add_replicas()->set_peer_id("32345678901234567890123456789012"); + g1.add_replicas()->set_peer_id("42345678901234567890123456789012"); + g1.mutable_replicas(0)->set_voting(true); + g1.mutable_replicas(1)->set_voting(true); + g1.mutable_replicas(2)->set_voting(true); + + auto leader = g1.replicas(0).peer_id(); + + ReplicaGroup g2 = g1; + EXPECT_EQ(error::Membership_NoMembershipChanges, ValidProgression(g1, g2, leader, 0)); + + g2 = g1; + g2.mutable_replicas(0)->set_voting(false); + EXPECT_EQ(error::Membership_LeaderRemovedFromVoting, ValidProgression(g1, g2, leader, 0)); + + g2 = g1; + g2.mutable_replicas(1)->set_voting(false); + g2.mutable_replicas(2)->set_voting(false); + EXPECT_EQ(error::Membership_TooManyMembershipChanges, ValidProgression(g1, g2, leader, 0)); + + g2 = g1; + g2.mutable_replicas()->erase(g2.mutable_replicas()->begin()); + EXPECT_EQ(error::Membership_LeaderRemovedFromVoting, ValidProgression(g1, g2, leader, 0)); + + // Delete a voting (non-leader) replica entirely. 
+ g2 = g1; + g2.mutable_replicas()->erase(++g2.mutable_replicas()->begin()); + EXPECT_EQ(error::OK, ValidProgression(g1, g2, leader, 0)); +} + +TEST_F(MembershipTest, MembershipCannotShrinkToOrBelowSupermajority) { + ReplicaGroup g1; + g1.add_replicas()->set_peer_id("12345678901234567890123456789012"); + g1.add_replicas()->set_peer_id("22345678901234567890123456789012"); + g1.add_replicas()->set_peer_id("32345678901234567890123456789012"); + g1.add_replicas()->set_peer_id("42345678901234567890123456789012"); + g1.mutable_replicas(0)->set_voting(true); + g1.mutable_replicas(1)->set_voting(true); + g1.mutable_replicas(2)->set_voting(true); + + auto leader = g1.replicas(0).peer_id(); + + ReplicaGroup g2 = g1; + g2.mutable_replicas(1)->set_voting(false); + EXPECT_EQ(error::Membership_SuperMajorityLost, ValidProgression(g1, g2, leader, 2)); + g2.mutable_replicas(2)->set_voting(false); + EXPECT_EQ(error::Membership_SuperMajorityLost, ValidProgression(g1, g2, leader, 2)); +} + +} // namespace svr2::raft diff --git a/enclave/raft/tests/raft.cc b/enclave/raft/tests/raft.cc new file mode 100644 index 0000000..3428722 --- /dev/null +++ b/enclave/raft/tests/raft.cc @@ -0,0 +1,277 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP peerid +//TESTDEP context +//TESTDEP hmac +//TESTDEP sip +//TESTDEP sender +//TESTDEP env +//TESTDEP env/test +//TESTDEP env +//TESTDEP util +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP noise-c +//TESTDEP libsodium + +#include +#include "raft/raft.h" +#include "peerid/peerid.h" +#include "env/env.h" +#include "util/log.h" +#include "proto/e2e.pb.h" +#include + +namespace svr2::raft { + +class RaftTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } + + enclaveconfig::RaftConfig DefaultConfig() { + enclaveconfig::RaftConfig config; + config.set_election_ticks(5); + config.set_heartbeat_ticks(1); + 
config.set_replication_chunk_bytes(1<<20); + config.set_replica_voting_timeout_ticks(15); + config.set_replica_membership_timeout_ticks(30); + return config; + } + + void SetUpRaft(int size, enclaveconfig::RaftConfig config) { + // Create a raft group with `size` voting replicas + ReplicaGroup g; + std::set peers; + for (int i = 0; i < size; i++) { + uint8_t peer_id[32]; + ASSERT_EQ(error::OK, env::environment->RandomBytes(peer_id, sizeof(peer_id))); + peerid::PeerID p(peer_id); + auto r = g.add_replicas(); + p.ToString(r->mutable_peer_id()); + r->set_voting(true); + peers.insert(p); + } + auto [mem, err] = Membership::FromProto(g); + ASSERT_EQ(error::OK, err); + for (auto peer : peers) { + auto memcpy = std::make_unique(*mem); + auto r = std::make_unique( + 1, // group + peer, + std::move(memcpy), + std::move(std::make_unique(1<<20)), // 1MB log + config, + false, + 0); + group_[peer] = std::move(r); + } + } + + void RouteMessages() { + bool quiescent = false; + int iter = 0; + LOG(INFO) << "--------------------- Message routing"; + while (!quiescent) { + LOG(INFO) << "------------- iteration " << iter++; + quiescent = true; + std::map> send; + for (auto i = group_.begin(); i != group_.end(); ++i) { + send[i->first] = i->second->SendableMessages(); + } + for (auto i = send.begin(); i != send.end(); ++i) { + for (auto msg : i->second) { + quiescent = false; + std::set send_to; + if (msg.to().has_value()) { + send_to.insert(*msg.to()); + } else if (group_.count(i->first)) { + send_to = group_[i->first]->peers(); + } else { + LOG(INFO) << "dropping targetted send to nonexistent peer " << i->first; + continue; + } + for (auto peer : send_to) { + if (group_.count(peer) == 0) { + LOG(INFO) << "dropping broadcast send to nonexistent peer " << peer; + continue; + } + LOG(VERBOSE) << " >>> send from " << i->first << " to " << peer; + LOG(VERBOSE) << " ::: " << MsgStr(msg.message()); + group_[peer]->Receive(&ctx, msg.message(), i->first); + LOG(VERBOSE) << " <<< send complete"; + } + } + } + }
+ } + + void CommitOnAll(const enclaveconfig::RaftConfig& config) { + bool quiescent = false; + std::map committed; + LOG(INFO) << "Waiting for commits to quiesce on all replicas"; + for (auto i = group_.begin(); i != group_.end(); ++i) { + committed[i->first] = i->second->commit_idx(); + LOG(INFO) << " initial on " << i->first << " : " << i->second->commit_idx(); + } + RouteMessages(); + while (!quiescent) { + LOG(INFO) << "TICK"; + quiescent = true; + for (int i = 0; i < config.heartbeat_ticks(); i++) { + for (auto i = group_.begin(); i != group_.end(); ++i) { + i->second->TimerTick(&ctx); + } + RouteMessages(); + } + for (auto i = group_.begin(); i != group_.end(); ++i) { + if (committed[i->first] != i->second->commit_idx()) { + quiescent = false; + committed[i->first] = i->second->commit_idx(); + LOG(INFO) << " update on " << i->first << " : " << i->second->commit_idx(); + } + } + } + LOG(INFO) << "Commits quiesced"; + } + + peerid::PeerID ElectLeader(const enclaveconfig::RaftConfig& config) { + LOG(INFO) << "Electing leader"; + std::set leaders; + while (leaders.size() == 0) { + RouteMessages(); + for (int i = 0; i < config.election_ticks() * 3; i++) { + for (auto i = group_.begin(); i != group_.end(); ++i) { + if (i->second->is_leader()) { leaders.insert(i->first); } + } + if (leaders.size()) break; + LOG(INFO) << "Tick: " << i; + for (auto i = group_.begin(); i != group_.end(); ++i) { + i->second->TimerTick(&ctx); + } + RouteMessages(); + } + } + CHECK(leaders.size() == 1); + LOG(INFO) << "Elected leader: " << leaders.begin()->DebugString(); + return *leaders.begin(); + } + + std::map> group_; + context::Context ctx; +}; + +TEST_F(RaftTest, CommitOnAll) { + auto config = DefaultConfig(); + SetUpRaft(3, config); + // Get a leader + peerid::PeerID leader = ElectLeader(config); + LOG(INFO) "============== SENDING LOG TO LEADER " << leader; + auto [loc, err] = group_[leader]->ClientRequest(&ctx, "abc"); + ASSERT_EQ(error::OK, err); + EXPECT_GE(loc.term(), 1); // 
may have been a few terms to elect leader + EXPECT_GE(loc.idx(), 1); // leader election adds entry to log + CommitOnAll(config); + for (auto i = group_.begin(); i != group_.end(); ++i) { + std::string last_log; + LOG(INFO) << "replica logs for " << i->first; + while(true) { + auto [idx, e] = i->second->TakeCommittedLog(); + if (idx == 0) break; + last_log = e.data(); + LOG(INFO) << "\tidx: " << idx << " : " << last_log; + if (last_log == "abc") { + EXPECT_EQ(idx, loc.idx()); + EXPECT_EQ(e.term(), loc.term()); + } + } + ASSERT_EQ(last_log, "abc"); + } +} + +TEST_F(RaftTest, CommitIfOneDown) { + auto config = DefaultConfig(); + SetUpRaft(3, config); + // Remove one of the participants. + group_.erase(group_.begin()); + peerid::PeerID leader = ElectLeader(config); + LOG(INFO) "============== SENDING LOG TO LEADER " << leader; + auto [loc, err] = group_[leader]->ClientRequest(&ctx, "abc"); + ASSERT_EQ(error::OK, err); + EXPECT_GE(loc.term(), 1); // may have been a few terms to elect leader + EXPECT_GE(loc.idx(), 1); // leader election adds entry to log + CommitOnAll(config); + for (auto i = group_.begin(); i != group_.end(); ++i) { + std::string last_log; + LOG(INFO) << "replica logs for " << i->first; + while(true) { + auto [idx, e] = i->second->TakeCommittedLog(); + if (idx == 0) break; + last_log = e.data(); + LOG(INFO) << "\tidx: " << idx << " : " << last_log; + if (last_log == "abc") { + EXPECT_EQ(idx, loc.idx()); + EXPECT_EQ(e.term(), loc.term()); + } + } + ASSERT_EQ(last_log, "abc"); + } +} + +TEST_F(RaftTest, SingleReplicaGroup) { + auto config = DefaultConfig(); + SetUpRaft(1, DefaultConfig()); + peerid::PeerID leader = ElectLeader(config); + LOG(INFO) "============== SENDING LOG TO LEADER " << leader; + auto [loc, err] = group_[leader]->ClientRequest(&ctx, "abc"); + ASSERT_EQ(error::OK, err); + EXPECT_GE(loc.term(), 1); // may have been a few terms to elect leader + EXPECT_GE(loc.idx(), 1); // leader election adds entry to log + CommitOnAll(config); + for 
(auto i = group_.begin(); i != group_.end(); ++i) { + std::string last_log; + LOG(INFO) << "replica logs for " << i->first; + while(true) { + auto [idx, e] = i->second->TakeCommittedLog(); + if (idx == 0) break; + last_log = e.data(); + LOG(INFO) << "\tidx: " << idx << " : " << last_log; + if (last_log == "abc") { + EXPECT_EQ(idx, loc.idx()); + EXPECT_EQ(e.term(), loc.term()); + } + } + ASSERT_EQ(last_log, "abc"); + } +} + +TEST_F(RaftTest, QuorumSize) { + EXPECT_EQ(Raft::quorum_size(3, 0), 2); + EXPECT_EQ(Raft::quorum_size(4, 0), 3); + EXPECT_EQ(Raft::quorum_size(2, 0), 2); + EXPECT_EQ(Raft::quorum_size(3, 1), 3); + EXPECT_EQ(Raft::quorum_size(4, 1), 3); +} + +TEST_F(RaftTest, RelinquishLeadership) { + auto config = DefaultConfig(); + SetUpRaft(2, DefaultConfig()); + auto leader = ElectLeader(config); + LOG(INFO) << "==================== LEADER " << leader << " calling RelinquishLeadership"; + group_[leader]->RelinquishLeadership(&ctx); + RouteMessages(); + EXPECT_FALSE(group_[leader]->is_leader()); + peerid::PeerID expected_new_leader; + for (auto iter = group_.begin(); iter != group_.end(); ++iter) { + if (iter->first == leader) { continue; } + expected_new_leader = iter->first; + } + EXPECT_TRUE(expected_new_leader.Valid()); + EXPECT_TRUE(group_[expected_new_leader]->is_leader()); +} + +} // namespace svr2::raft diff --git a/enclave/raft/tests/setdiffsize.cc b/enclave/raft/tests/setdiffsize.cc new file mode 100644 index 0000000..cff1fe1 --- /dev/null +++ b/enclave/raft/tests/setdiffsize.cc @@ -0,0 +1,48 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP peerid +//TESTDEP context +//TESTDEP sip +//TESTDEP sender +//TESTDEP env +//TESTDEP env/test +//TESTDEP env +//TESTDEP util +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include +#include "peerid/peerid.h" +#include "raft/membership.h" + +namespace svr2::raft { + +class SetDiffTest : public 
::testing::Test {}; + +TEST_F(SetDiffTest, Basic) { + std::set a; + std::set b; + ASSERT_EQ(0, SetDiffSize(a, b)); + uint8_t p1[32] = {1}; + uint8_t p2[32] = {2}; + uint8_t p3[32] = {3}; + uint8_t p4[32] = {4}; + a.insert(peerid::PeerID(p1)); + ASSERT_EQ(1, SetDiffSize(a, b)); + ASSERT_EQ(0, SetDiffSize(b, a)); + b.insert(peerid::PeerID(p2)); + b.insert(peerid::PeerID(p3)); + b.insert(peerid::PeerID(p4)); + ASSERT_EQ(1, SetDiffSize(a, b)); + ASSERT_EQ(3, SetDiffSize(b, a)); + a.insert(peerid::PeerID(p4)); + ASSERT_EQ(1, SetDiffSize(a, b)); + ASSERT_EQ(2, SetDiffSize(b, a)); +} + +} // namespace svr2::raft diff --git a/enclave/raft/types.h b/enclave/raft/types.h new file mode 100644 index 0000000..8bf52e3 --- /dev/null +++ b/enclave/raft/types.h @@ -0,0 +1,17 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_RAFT_TYPES_H__ +#define __SVR2_RAFT_TYPES_H__ + +#include + +namespace svr2::raft { + +typedef uint64_t LogIdx; +typedef uint64_t TermId; +typedef uint64_t GroupId; + +} // namespace svr2::raft + +#endif // __SVR2_RAFT_TYPES_H__ diff --git a/enclave/releases/.keep b/enclave/releases/.keep new file mode 100644 index 0000000..e69de29 diff --git a/enclave/releases/default.6ee1042f9e20f880326686dd4ba50c25359f01e9f733eeba4382bca001d45094 b/enclave/releases/default.6ee1042f9e20f880326686dd4ba50c25359f01e9f733eeba4382bca001d45094 new file mode 100644 index 0000000..36c52cf Binary files /dev/null and b/enclave/releases/default.6ee1042f9e20f880326686dd4ba50c25359f01e9f733eeba4382bca001d45094 differ diff --git a/enclave/releases/small.a8a261420a6bb9b61aa25bf8a79e8bd20d7652531feb3381cbffd446d270be95 b/enclave/releases/small.a8a261420a6bb9b61aa25bf8a79e8bd20d7652531feb3381cbffd446d270be95 new file mode 100644 index 0000000..7a68c83 Binary files /dev/null and b/enclave/releases/small.a8a261420a6bb9b61aa25bf8a79e8bd20d7652531feb3381cbffd446d270be95 differ diff --git a/enclave/sender/sender.cc b/enclave/sender/sender.cc 
new file mode 100644 index 0000000..a1eee1e --- /dev/null +++ b/enclave/sender/sender.cc @@ -0,0 +1,19 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "sender/sender.h" +#include "env/env.h" +#include "metrics/metrics.h" + +namespace svr2::sender { + +// Send a message to the host. +void Send(const EnclaveMessage& msg) { + std::string serialized; + CHECK(msg.SerializeToString(&serialized)); + CHECK(error::OK == env::environment->SendMessage(serialized)); + COUNTER(sender, enclave_messages_sent)->Increment(); + COUNTER(sender, enclave_bytes_sent)->IncrementBy(serialized.size()); +} + +} // namespace svr2::sender diff --git a/enclave/sender/sender.h b/enclave/sender/sender.h new file mode 100644 index 0000000..514c23f --- /dev/null +++ b/enclave/sender/sender.h @@ -0,0 +1,17 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_SENDER_SENDER_H__ +#define __SVR2_SENDER_SENDER_H__ + +#include "proto/msgs.pb.h" +#include "proto/error.pb.h" + +namespace svr2::sender { + +// Send a message to the host. 
+void Send(const EnclaveMessage& msg); + +} // namespace svr2::sender + +#endif // __SVR2_SENDER_SENDER_H__ diff --git a/enclave/sender/tests/sender.cc b/enclave/sender/tests/sender.cc new file mode 100644 index 0000000..c286e3d --- /dev/null +++ b/enclave/sender/tests/sender.cc @@ -0,0 +1,37 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP env +//TESTDEP env/test +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include +#include +#include "proto/msgs.pb.h" +#include "proto/error.pb.h" +#include "env/env.h" +#include "env/test/test.h" +#include "sender/sender.h" + +namespace svr2::sender { + +TEST(SenderTest, SendViaTestEnv) { + env::Init(); + EnclaveMessage m; + m.mutable_peer_message()->set_syn("abc"); + Send(m); + Send(m); + Send(m); + std::vector got = env::test::SentMessages(); + ASSERT_EQ(3, got.size()); + ASSERT_EQ("abc", got[0].peer_message().syn()); + ASSERT_EQ("abc", got[1].peer_message().syn()); + ASSERT_EQ("abc", got[2].peer_message().syn()); +} + +} // namespace svr2::sender diff --git a/enclave/sip/halfsiphash.c b/enclave/sip/halfsiphash.c new file mode 120000 index 0000000..01cccc7 --- /dev/null +++ b/enclave/sip/halfsiphash.c @@ -0,0 +1 @@ +../SipHash/halfsiphash.c \ No newline at end of file diff --git a/enclave/sip/halfsiphash.h b/enclave/sip/halfsiphash.h new file mode 120000 index 0000000..1933811 --- /dev/null +++ b/enclave/sip/halfsiphash.h @@ -0,0 +1 @@ +../SipHash/halfsiphash.h \ No newline at end of file diff --git a/enclave/sip/hasher.cc b/enclave/sip/hasher.cc new file mode 100644 index 0000000..540d241 --- /dev/null +++ b/enclave/sip/hasher.cc @@ -0,0 +1,25 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "sip/hasher.h" +#include "env/env.h" +#include "util/endian.h" +extern "C" { +#include "sip/halfsiphash.h" +} // extern "C" + +namespace svr2::sip { + +Hasher::Hasher() { + 
CHECK(error::OK == env::environment->RandomBytes(halfsiphash_key_, sizeof(halfsiphash_key_))); +} +Hasher::Hasher(const Hasher& copy) { + memcpy(halfsiphash_key_, copy.halfsiphash_key_, sizeof(halfsiphash_key_)); +} +size_t Hasher::Hash(const void* data, size_t size) const { + uint8_t out[8]; + halfsiphash(data, size, halfsiphash_key_, out, sizeof(out)); + return util::BigEndian64FromBytes(out); +} + +} // namespace svr2::sip diff --git a/enclave/sip/hasher.h b/enclave/sip/hasher.h new file mode 100644 index 0000000..6e9f3f4 --- /dev/null +++ b/enclave/sip/hasher.h @@ -0,0 +1,24 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_SIP_HASHER_H__ +#define __SVR2_SIP_HASHER_H__ + +#include +#include + +namespace svr2::sip { + +class Hasher { + public: + Hasher(); + Hasher(const Hasher& copy); + protected: + size_t Hash(const void* data, size_t bytes) const; + private: + uint8_t halfsiphash_key_[8]; +}; + +} // namespace svr2::sip + +#endif // __SVR2_SIP_HASHER_H__ diff --git a/enclave/sip/tests/hasher.cc b/enclave/sip/tests/hasher.cc new file mode 100644 index 0000000..1f39f45 --- /dev/null +++ b/enclave/sip/tests/hasher.cc @@ -0,0 +1,44 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP sip +//TESTDEP env +//TESTDEP env/test +//TESTDEP gtest +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium + +#include +#include +#include +#include "sip/hasher.h" +#include "env/env.h" + +namespace svr2::sip { + +class HashInts : public Hasher { + public: + size_t operator()(const uint32_t& a) const { + return Hash(&a, sizeof(a)); + } +}; + +class HasherTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + env::Init(); + } +}; + +TEST_F(HasherTest, HashInts) { + std::unordered_map m; + for (uint32_t i = 0; i < 5000; i++) { + m[i] = i; + } + for (uint32_t i = 0; i < 5000; i++) { + ASSERT_EQ(m[i], i); + } +} + +} // namespace svr2::sip diff --git 
a/enclave/socketwrap/socket.cc b/enclave/socketwrap/socket.cc new file mode 100644 index 0000000..6c66891 --- /dev/null +++ b/enclave/socketwrap/socket.cc @@ -0,0 +1,95 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include +#include +#include +#include +#include + +#include "socketwrap/socket.h" + +#include "context/context.h" +#include "util/log.h" +#include "util/endian.h" + +namespace svr2::socketwrap { + +Socket::Socket(int fd) : fd_(fd) {} + +error::Error Socket::ReadAll(uint8_t* buf, size_t size) { + while (size) { + ssize_t got = recv(fd_, buf, size, 0); + if (got == 0) { + return COUNTED_ERROR(Socket_ReadEOF); + } else if (got < 0) { + switch (errno) { + case EINTR: + continue; + default: + LOG(ERROR) << "Socket " << fd_ << " recv error: " << errno << " - " << strerror(errno); + return COUNTED_ERROR(Socket_Read); + } + } else { + size -= got; + buf += got; + } + } + return error::OK; +} + +error::Error Socket::WriteAll(uint8_t* buf, size_t size) { + while (size) { + ssize_t got = send(fd_, buf, size, MSG_NOSIGNAL); + if (got < 0) { + switch (errno) { + case EINTR: + continue; + default: + LOG(ERROR) << "Socket " << fd_ << " send error: " << errno << " - " << strerror(errno); + return COUNTED_ERROR(Socket_Write); + } + } else { + size -= got; + buf += got; + } + } + return error::OK; +} + +error::Error Socket::ReadPB(context::Context* ctx, google::protobuf::MessageLite* pb) { + ACQUIRE_LOCK(read_mu_, ctx, lock_socket_read); + uint8_t uint32_buf[4] = {0}; + RETURN_IF_ERROR(ReadAll(uint32_buf, sizeof(uint32_buf))); + size_t to_read = util::BigEndian32FromBytes(uint32_buf); + if (to_read > INT32_MAX) { + return COUNTED_ERROR(Socket_ReadTooBig); + } + if (read_buf_.size() < to_read) { + read_buf_.resize(to_read); + } + RETURN_IF_ERROR(ReadAll(read_buf_.data(), to_read)); + if (!pb->ParseFromArray(read_buf_.data(), to_read)) { + return COUNTED_ERROR(Socket_ParseIncoming); + } + return error::OK; +} + +error::Error 
Socket::WritePB(context::Context* ctx, const google::protobuf::MessageLite& pb) { + ACQUIRE_LOCK(write_mu_, ctx, lock_socket_write); + size_t size = pb.ByteSizeLong(); + if (size > INT32_MAX) { + return COUNTED_ERROR(Socket_WriteTooBig); + } + write_buf_.resize(size); + uint8_t* end = pb.SerializeWithCachedSizesToArray(write_buf_.data()); + size = end - write_buf_.data(); + + uint8_t uint32_buf[4] = {0}; + util::BigEndian32Bytes(size, uint32_buf); + RETURN_IF_ERROR(WriteAll(uint32_buf, sizeof(uint32_buf))); + RETURN_IF_ERROR(WriteAll(write_buf_.data(), size)); + return error::OK; +} + +} // namespace svr2::socketwrap diff --git a/enclave/socketwrap/socket.h b/enclave/socketwrap/socket.h new file mode 100644 index 0000000..1aaa0aa --- /dev/null +++ b/enclave/socketwrap/socket.h @@ -0,0 +1,39 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_SOCKET_SOCKET_H__ +#define __SVR2_SOCKET_SOCKET_H__ + +#include +#include "proto/error.pb.h" +#include "context/context.h" +#include "util/mutex.h" + +#include + +namespace svr2::socketwrap { + +class Socket { + public: + Socket(int fd); // `fd` should already be bound/listening. + + error::Error ReadPB(context::Context* ctx, google::protobuf::MessageLite* pb) EXCLUDES(read_mu_); + error::Error WritePB(context::Context* ctx, const google::protobuf::MessageLite& pb) EXCLUDES(write_mu_); + + public_for_test: + error::Error ReadAll(uint8_t* buf, size_t size); + error::Error WriteAll(uint8_t* buf, size_t size); + + private: + int fd_; + util::mutex read_mu_; + util::mutex write_mu_; + // Reusable buffers for reading/writing. Will grow to be the max size + // of all messages they've seen. 
+ std::vector read_buf_ GUARDED_BY(read_mu_); + std::vector write_buf_ GUARDED_BY(write_mu_); +}; + +} // namespace svr2::socketwrap + +#endif // __SVR2_SOCKET_SOCKET_H__ diff --git a/enclave/socketwrap/tests/socket.cc b/enclave/socketwrap/tests/socket.cc new file mode 100644 index 0000000..43d32e2 --- /dev/null +++ b/enclave/socketwrap/tests/socket.cc @@ -0,0 +1,91 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP socketwrap +//TESTDEP util +//TESTDEP env +//TESTDEP env/test +//TESTDEP context +//TESTDEP metrics +//TESTDEP proto +//TESTDEP protobuf-lite +//TESTDEP libsodium +#include +#include "socketwrap/socket.h" +#include "proto/tests.pb.h" +#include "env/env.h" +#include "util/log.h" +#include "util/endian.h" +#include +#include + +namespace svr2::socketwrap { + +class SocketTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } +}; + +TEST_F(SocketTest, SendAndReceive) { + int socks[2]; + ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, socks)); + + Socket a(socks[0]); + Socket b(socks[1]); + + tests::SimplePB p1; + p1.set_str("abcdefg"); + { + context::Context ctx; + ASSERT_EQ(error::OK, a.WritePB(&ctx, p1)); + tests::SimplePB p2; + ASSERT_EQ(error::OK, b.ReadPB(&ctx, &p2)); + ASSERT_EQ(p1.str(), p2.str()); + } + + for (int i = 0; i < 10; i++) { + context::Context ctx; + ASSERT_EQ(error::OK, b.WritePB(&ctx, p1)); + } + for (int i = 0; i < 10; i++) { + context::Context ctx; + tests::SimplePB p2; + ASSERT_EQ(error::OK, a.ReadPB(&ctx, &p2)); + ASSERT_EQ(p1.str(), p2.str()); + } + + ASSERT_EQ(0, shutdown(socks[0], SHUT_WR)); + { + context::Context ctx; + EXPECT_EQ(error::Socket_Write, a.WritePB(&ctx, p1)); + } + { + context::Context ctx; + tests::SimplePB p2; + EXPECT_EQ(error::Socket_ReadEOF, b.ReadPB(&ctx, &p2)); + } + + ASSERT_EQ(0, close(socks[0])); + ASSERT_EQ(0, close(socks[1])); +} + +TEST_F(SocketTest, ReadTooBig) { + int socks[2]; + ASSERT_EQ(0, 
socketpair(AF_UNIX, SOCK_STREAM, 0, socks)); + + Socket a(socks[0]); + Socket b(socks[1]); + uint8_t too_big_buf[4] = {0xff, 0xff, 0xff, 0xff}; + a.WriteAll(too_big_buf, sizeof(too_big_buf)); + + tests::SimplePB pb; + context::Context ctx; + ASSERT_EQ(error::Socket_ReadTooBig, b.ReadPB(&ctx, &pb)); + ASSERT_EQ(0, close(socks[0])); + ASSERT_EQ(0, close(socks[1])); +} + +} // namespace svr2::socketwrap diff --git a/enclave/svr2.conf b/enclave/svr2.conf new file mode 100644 index 0000000..eb6f220 --- /dev/null +++ b/enclave/svr2.conf @@ -0,0 +1,21 @@ +Debug=0 +# Each TCS requires its own stack, plus a few "bonus" pages: +# +# - 1 TCS page (enclave independent) +# - 2 State Save Area (SSA) pages (enclave independent) +# - 1 guard page (enclave independent) +# - 1 TLS page (depends on enclave binary - number of pages needed to hold .tdata and .tbss in the enclave.signed elf file. At time of writing this fits in 1 page.) +# - 1 page for thread-specific data (TSD) slots +# +# ...and so each TCS consumes (6 + NumStackPages) EPC pages, so NumHeapPages = NumEpcPages - (NumTCS * (6 + NumStackPages)). +# +# On top of that, attestation services may consume additional EPC memory (6 MiB in our case). +# +# This configuration requires a host that has at least 120 GiB EPC memory available. +# +# 120 GiB - 6 MiB => 128842727424 bytes => NumEpcPages = 31455744 (4 KiB pages) +NumHeapPages=31324288 +NumStackPages=2048 +NumTCS=64 +ProductID=1 +SecurityVersion=1 diff --git a/enclave/svr2/.keep b/enclave/svr2/.keep new file mode 100644 index 0000000..9af1401 --- /dev/null +++ b/enclave/svr2/.keep @@ -0,0 +1 @@ +This file exists so that Git knows to keep this directory around.
diff --git a/enclave/svr2_small.conf b/enclave/svr2_small.conf new file mode 100644 index 0000000..1859eda --- /dev/null +++ b/enclave/svr2_small.conf @@ -0,0 +1,7 @@ +Debug=0 +# Roughly 8G +NumHeapPages=2000000 +NumStackPages=2048 +NumTCS=16 +ProductID=1 +SecurityVersion=1 diff --git a/enclave/svr2_test.conf b/enclave/svr2_test.conf new file mode 100644 index 0000000..1277e38 --- /dev/null +++ b/enclave/svr2_test.conf @@ -0,0 +1,7 @@ +# Enclave settings: a small enclave with debug enabled for testing purposes. +Debug=1 +NumHeapPages=20000 +NumStackPages=2048 +NumTCS=64 +ProductID=1 +SecurityVersion=1 diff --git a/enclave/test_deps.sh b/enclave/test_deps.sh new file mode 100755 index 0000000..40d218d --- /dev/null +++ b/enclave/test_deps.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +for testfile in `find ./ -type f | grep /tests/ | grep cc$`; do + testfile="$(echo "$testfile" | sed 's#./##')" + testname="$(echo "$testfile" | sed 's/\.cc$/\.test/')" + echo 1>&2 "TEST: $testfile -> $testname" + deps="build/$(dirname $(dirname "$testfile"))/TEST.a$(grep '^//TESTDEP ' "$testfile" | awk '{printf " build/%s/TEST.a",$2}')" + echo 1>&2 " Deps: $deps" + args="$(grep '^//TESTARG ' "$testfile" | awk '{printf "%s ",$2}')" + echo 1>&2 " Args: $args" + echo "build/$testname: $testfile $deps" + echo -e '\t$(QUIET) echo -e "BUILD\t$@"' + echo -e '\t$(QUIET) mkdir -p \$(@D)' + echo -e "\t\$(QUIET) \$(CXX) \$(TEST_CXXFLAGS) -o \$@ $testfile $deps $args \$(TEST_LDFLAGS)" + echo "test: build/$testname.out" + echo ".PRECIOUS: build/$testname build/$testname.out" +done | tee .testdepends diff --git a/enclave/testhost/testhost.cc b/enclave/testhost/testhost.cc new file mode 100644 index 0000000..7bf2939 --- /dev/null +++ b/enclave/testhost/testhost.cc @@ -0,0 +1,264 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../host/enclave/c/svr2_u.h" 
+#include "proto/e2e.pb.h" +#include "proto/enclaveconfig.pb.h" +#include "proto/error.pb.h" +#include "proto/msgs.pb.h" +#include "util/constant.h" +#include "util/macros.h" +#include "attestation/attestation.h" + +// OCALL implementation +static std::deque<::svr2::EnclaveMessage> out_msgs; + +void svr2_output_message(size_t msg_size, unsigned char* msg) { + fprintf(stderr, "received message\n"); + ::svr2::EnclaveMessage em; + CHECK(em.ParseFromArray(msg, msg_size)); + out_msgs.emplace_back(std::move(em)); +} + +namespace { + +oe_enclave_t* create_enclave(const char* enclave_path, uint32_t flags) { + oe_enclave_t* enclave = NULL; + + printf("Host: Enclave library %s\n", enclave_path); + oe_result_t result = oe_create_svr2_enclave( + enclave_path, OE_ENCLAVE_TYPE_AUTO, flags, NULL, 0, &enclave); + + if (result != OE_OK) { + printf("Host: oe_create_attestation_enclave failed. %s", + oe_result_str(result)); + } else { + printf("Host: Enclave successfully created.\n"); + } + return enclave; +} + +void terminate_enclave(oe_enclave_t* enclave) { + oe_terminate_enclave(enclave); + printf("Host: Enclave successfully terminated.\n"); +} + +static void print_peer_id(std::array peer_id) { + fprintf(stderr, "{"); + for (size_t i = 0; i < 31; ++i) { + fprintf(stderr, "%u, ", peer_id[i]); + } + fprintf(stderr, "%u }\n", peer_id[31]); +} + +template +std::string BytesToString(std::array arr) { + std::string s(N, 0); + std::copy(arr.begin(), arr.end(), s.begin()); + return s; +} + +void VerifyAttestation(const std::string& evidence, + const std::string& endorsements, + const std::array& expected_id) { + auto [claims, claims_length] = svr2::attestation::VerifyAndReadClaims(evidence, endorsements); + auto free_claims_known_size = [claims_length=claims_length](oe_claim_t* ptr) { + return oe_free_claims(ptr, claims_length); + }; + std::unique_ptr free_claims( + claims, free_claims_known_size); + + // evidence is verified, now check id + std::array out{0}; + + ::svr2::error::Error err 
= + ::svr2::attestation::ReadKeyFromVerifiedClaims(claims, claims_length, out); + CHECK(::svr2::util::ConstantTimeEquals(out, expected_id)); +} + +class TestEnclave { + oe_enclave_t* enclave_; + std::array id_{0}; + uint32_t flags_; + + public: + TestEnclave(const char* enclave_path, uint32_t flags) : flags_(flags) { + this->enclave_ = create_enclave(enclave_path, flags); + } + ~TestEnclave() { terminate_enclave(this->enclave_); } + + bool is_simulated() const { return this->flags_ & OE_ENCLAVE_FLAG_SIMULATE; } + + ::svr2::error::Error Init(::svr2::enclaveconfig::EnclaveConfig config) { + int ret = ::svr2::error::OK; + std::string serialized_cfg; + CHECK(config.SerializeToString(&serialized_cfg)); + oe_result_t result = + svr2_init(this->enclave_, &ret, serialized_cfg.size(), + reinterpret_cast(serialized_cfg.data()), + this->id_.data()); + fprintf(stderr, "Created enclave with id: "); + print_peer_id(this->id_); + return ::svr2::error::OK; + } + + void Connect(TestEnclave& other) { + oe_result_t result = OE_OK; + int ret = ::svr2::error::OK; + + // request that peer 0 connect to peer 1 + // construct and serialize the H2E message + ::svr2::UntrustedMessage h2e_connect_cmd; + auto her = h2e_connect_cmd.mutable_h2e_request(); + her->set_request_id(1001); + + auto req = her->mutable_create_new_raft_group(); + req->set_min_voting_replicas(1); + req->set_max_voting_replicas(2); + + std::string serialized_req; + CHECK(h2e_connect_cmd.SerializeToString(&serialized_req)); + + // send command to enclave + result = svr2_input_message( + this->enclave_, &ret, serialized_req.size(), + reinterpret_cast(serialized_req.data())); + CHECK(result == OE_OK); + CHECK(ret == ::svr2::error::OK); + + // Get the peer message + svr2::EnclaveMessage emsg = std::move(out_msgs.front()); + out_msgs.pop_front(); + CHECK(emsg.inner_case() == svr2::EnclaveMessage::kPeerMessage); + + // if not simulating, extract the attestation and verify it + if (!this->is_simulated()) { + 
::svr2::e2e::ConnectRequest conn_request; + CHECK(conn_request.ParseFromString(emsg.peer_message().data())); + + auto remote_attestation = conn_request.attestation(); + VerifyAttestation(remote_attestation.evidence(), + remote_attestation.endorsements(), this->id_); + } + + // get the HostToEnclaveResponse + svr2::EnclaveMessage h2e_response = std::move(out_msgs.front()); + out_msgs.pop_front(); + CHECK(h2e_response.inner_case() == + svr2::EnclaveMessage::kH2EResponse); + CHECK(h2e_response.h2e_response().status() == + ::svr2::error::OK); + + // Forward the peer message to peer 1 + svr2::UntrustedMessage e2e_connect_request; + *e2e_connect_request.mutable_peer_message() = + std::move(*emsg.mutable_peer_message()); + // The peer_id field on a PeerMessage is the ID of the sender + e2e_connect_request.mutable_peer_message()->set_peer_id( + BytesToString(this->id_)); + + CHECK(e2e_connect_request.SerializeToString(&serialized_req)); + + // send the peer message to other enclave + result = svr2_input_message( + other.enclave_, &ret, serialized_req.size(), + reinterpret_cast(serialized_req.data())); + CHECK(result == OE_OK); + CHECK(ret == ::svr2::error::OK); + + // the other enclave produces exactly one message - a PeerMessage to finish + // the handshake + CHECK(out_msgs.size() == 1); + emsg = std::move(out_msgs.front()); + out_msgs.pop_front(); + CHECK(emsg.inner_case() == svr2::EnclaveMessage::kPeerMessage); + + // forward this message to our enclave + svr2::UntrustedMessage e2e_connect_response; + *e2e_connect_response.mutable_peer_message() = + std::move(*emsg.mutable_peer_message()); + + // this message is from `other` + e2e_connect_response.mutable_peer_message()->set_peer_id( + BytesToString(other.id_)); + + CHECK(e2e_connect_response.SerializeToString(&serialized_req)); + result = svr2_input_message( + this->enclave_, &ret, serialized_req.size(), + reinterpret_cast(serialized_req.data())); + CHECK(result == OE_OK); + CHECK(ret == ::svr2::error::OK); + + // There
should be no message from our enclave + CHECK(out_msgs.size() == 0); + fprintf(stderr, "handshake successful\n"); + } +}; + +}; // namespace +bool check_simulate_opt(int* argc, const char* argv[]) { + for (int i = 0; i < *argc; i++) { + if (strcmp(argv[i], "--simulate") == 0) { + fprintf(stdout, "Running in simulation mode\n"); + memmove(&argv[i], &argv[i + 1], (*argc - i) * sizeof(char*)); + (*argc)--; + return true; + } + } + return false; +} + +int main(int argc, const char* argv[]) { + uint32_t flags = OE_ENCLAVE_FLAG_DEBUG; + // oe_uuid_t* format_id = nullptr; + + bool simulate = false; + if (check_simulate_opt(&argc, argv)) { + flags |= OE_ENCLAVE_FLAG_SIMULATE; + simulate = true; + } else { + CHECK(OE_OK == oe_verifier_initialize()); + } + + // check_simulate_opt decrements argc if it found `--simulate` + if (argc != 2) { + fprintf(stderr, "Usage: %s enclave_image_path [ --simulate ]\n", argv[0]); + return 1; + } + + printf("Host: Creating enclave 0\n"); + TestEnclave e0(argv[1], flags); + + printf("Host: Creating enclave 1\n"); + TestEnclave e1(argv[1], flags); + + { + // create a config pb + ::svr2::enclaveconfig::EnclaveConfig config; + + auto raft_config = config.mutable_raft(); + raft_config->set_election_ticks(4); + raft_config->set_heartbeat_ticks(2); + raft_config->set_replication_chunk_bytes(1 << 20); + + ::svr2::error::Error err = e0.Init(config); + CHECK(err == ::svr2::error::OK); + err = e1.Init(config); + CHECK(err == ::svr2::error::OK); + } + + e0.Connect(e1); + return 0; +} diff --git a/enclave/timeout/tests/timeout.cc b/enclave/timeout/tests/timeout.cc new file mode 100644 index 0000000..9fb4c2d --- /dev/null +++ b/enclave/timeout/tests/timeout.cc @@ -0,0 +1,87 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP timeout +//TESTDEP metrics +//TESTDEP util +//TESTDEP context +//TESTDEP env +//TESTDEP env/test +//TESTDEP env +//TESTDEP sip +//TESTDEP proto +//TESTDEP protobuf-lite 
+//TESTDEP libsodium + +#include +#include "timeout/timeout.h" +#include "context/context.h" +#include "env/env.h" +#include "env/test/test.h" + +namespace svr2::timeout { + +class TimeoutTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + env::Init(); + } + + Timeout t; + context::Context ctx; +}; + +TEST_F(TimeoutTest, TicksStartAtZero) { + ASSERT_EQ(t.ticks(), 0); +} + +TEST_F(TimeoutTest, TimeoutRuns) { + bool ran = false; + t.SetTimeout(&ctx, 1, [&ran](context::Context* ctx){ ran = true; }); + ASSERT_FALSE(ran); + t.TimerTick(&ctx); + ASSERT_TRUE(ran); +} + +TEST_F(TimeoutTest, TimeoutCancels) { + bool ran = false; + Cancel c = t.SetTimeout(&ctx, 1, [&ran](context::Context* ctx){ ran = true; }); + ASSERT_FALSE(ran); + t.CancelTimeout(&ctx, c); + t.TimerTick(&ctx); + ASSERT_FALSE(ran); +} + +TEST_F(TimeoutTest, TimeoutCancelAfterRunIsFine) { + bool ran = false; + Cancel c = t.SetTimeout(&ctx, 1, [&ran](context::Context* ctx){ ran = true; }); + ASSERT_FALSE(ran); + t.TimerTick(&ctx); + ASSERT_TRUE(ran); + t.CancelTimeout(&ctx, c); +} + +TEST_F(TimeoutTest, MultipleTimeoutsAtSameTick) { + int ran = 0; + t.SetTimeout(&ctx, 1, [&ran](context::Context* ctx){ ran++; }); + Cancel c = t.SetTimeout(&ctx, 1, [&ran](context::Context* ctx){ ran++; }); + t.SetTimeout(&ctx, 1, [&ran](context::Context* ctx){ ran++; }); + t.SetTimeout(&ctx, 1, [&ran](context::Context* ctx){ ran++; }); + t.CancelTimeout(&ctx, c); + t.TimerTick(&ctx); + ASSERT_EQ(ran, 3); +} + +TEST_F(TimeoutTest, FarFutureTimeout) { + bool ran = false; + t.SetTimeout(&ctx, 1001, [&ran](context::Context* ctx){ ran = true; }); + for (int i = 0; i < 1000; i++) { + t.TimerTick(&ctx); + ASSERT_FALSE(ran); + } + t.TimerTick(&ctx); + ASSERT_TRUE(ran); +} + +} // namespace svr2::timeout diff --git a/enclave/timeout/timeout.cc b/enclave/timeout/timeout.cc new file mode 100644 index 0000000..e4287fc --- /dev/null +++ b/enclave/timeout/timeout.cc @@ -0,0 +1,58 @@ +// Copyright 2023 Signal 
Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "timeout/timeout.h" +#include "metrics/metrics.h" + +namespace svr2::timeout { + +Timeout::Timeout() : ticks_(0), timeout_cancel_gen_(0) {} + +void Timeout::TimerTick(context::Context* ctx) { + ACQUIRE_NAMED_LOCK(lock, mu_, ctx, lock_timeout); + ticks_++; + auto timeouts_to_run = timeouts_.find(ticks_); + if (timeouts_to_run == timeouts_.end()) { + return; + } + TimeoutSet ts = std::move(timeouts_to_run->second); + timeouts_.erase(timeouts_to_run); + // We unlock before calling timeout methods, since they may want to do things + // that also create timeouts. + lock.unlock(); + for (auto iter = ts.begin(); iter != ts.end(); ++iter) { + COUNTER(timeout, timeouts_run)->Increment(); + iter->second(ctx); + } +} + +Cancel Timeout::SetTimeout(context::Context* ctx, util::Ticks ticks_from_now, std::function fn) { + ACQUIRE_LOCK(mu_, ctx, lock_timeout); + CHECK(ticks_from_now + ticks_ > ticks_); + Cancel tc(ticks_from_now + ticks_, ++timeout_cancel_gen_); + auto finder = timeouts_.find(tc.at_tick_); + if (finder == timeouts_.end()) { + auto [i, b] = timeouts_.emplace( + tc.at_tick_, + std::unordered_map>()); + finder = i; + } + + finder->second[tc.cancel_id_] = fn; + COUNTER(timeout, timeouts_created)->Increment(); + return tc; +} + +void Timeout::CancelTimeout(context::Context* ctx, const Cancel& tc) { + ACQUIRE_LOCK(mu_, ctx, lock_timeout); + auto finder = timeouts_.find(tc.at_tick_); + if (finder != timeouts_.end()) { + auto f2 = finder->second.find(tc.cancel_id_); + if (f2 != finder->second.end()) { + COUNTER(timeout, timeouts_cancelled)->Increment(); + finder->second.erase(f2); + } + } +} + +} // namespace svr2::timeout diff --git a/enclave/timeout/timeout.h b/enclave/timeout/timeout.h new file mode 100644 index 0000000..de8c961 --- /dev/null +++ b/enclave/timeout/timeout.h @@ -0,0 +1,62 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef 
__SVR2_TIMEOUT_TIMEOUT_H__ +#define __SVR2_TIMEOUT_TIMEOUT_H__ + +#include +#include +#include +#include +#include +#include "util/ticks.h" +#include "context/context.h" + +namespace svr2::timeout { + +class Timeout; +class Cancel { + public: + Cancel() : at_tick_(0), cancel_id_(0) {} + private: + Cancel(util::Ticks at_tick, int64_t cancel_id) : at_tick_(at_tick), cancel_id_(cancel_id) {} + util::Ticks at_tick_; + int64_t cancel_id_; + friend class Timeout; +}; + +typedef std::function TimeoutFn; + +class Timeout { + public: + Timeout(); + // SetTimeout provides a function that will be called [ticks_from_now] ticks in the future (min 1). + // This function will be called at that time, once, unless CancelTimeout is called on the returned + // value before that time. + Cancel SetTimeout(context::Context* ctx, util::Ticks ticks_from_now, TimeoutFn fn) EXCLUDES(mu_); + // CancelTimeout cancels a function that was scheduled for the future. May be called any number + // of times on a Cancel, and may be called after the ticks for the given function have + // passed. + void CancelTimeout(context::Context* ctx, const Cancel& c) EXCLUDES(mu_); + // Called whenever the host gives us a TimerTick. 
+ void TimerTick(context::Context* ctx) EXCLUDES(mu_); + +#ifdef IS_TEST + util::Ticks ticks() const EXCLUDES(mu_) { + util::unique_lock lock(mu_); + return ticks_; + } +#endif + + private: + // Time and Timeouts + mutable util::mutex mu_; + util::Ticks ticks_ GUARDED_BY(mu_); + int64_t timeout_cancel_gen_ GUARDED_BY(mu_); + typedef std::unordered_map TimeoutSet; + std::unordered_map timeouts_ GUARDED_BY(mu_); +}; + +} // namespace svr2::timeout + +#endif // __SVR2_TIMEOUT_TIMEOUT_H__ diff --git a/enclave/util/bytes.h b/enclave/util/bytes.h new file mode 100644 index 0000000..baf4f06 --- /dev/null +++ b/enclave/util/bytes.h @@ -0,0 +1,36 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_BYTES_H +#define __SVR2_UTIL_BYTES_H + +#include +#include +#include +#include +#include "util/macros.h" +#include "proto/error.pb.h" + +namespace svr2::util { + +template +std::string ByteArrayToString(const std::array& bytes) { + std::string result; + result.resize(N, '\0'); + std::copy(bytes.begin(), bytes.end(), result.begin()); + return result; +} + +template +std::pair, error::Error> StringToByteArray(const std::string& str) { + std::array result{0}; + if (str.size() > N) { + return std::make_pair(result, error::Util_ArrayCopyTooBig); + } + std::copy(str.begin(), str.end(), result.begin()); + return std::make_pair(result, error::OK); +} + +} // namespace svr2::util + +#endif // __SVR2_UTIL_BYTES_H diff --git a/enclave/util/constant.h b/enclave/util/constant.h new file mode 100644 index 0000000..76940ee --- /dev/null +++ b/enclave/util/constant.h @@ -0,0 +1,31 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_CONSTANT_H__ +#define __SVR2_UTIL_CONSTANT_H__ + +namespace svr2::util { + +// Templatized to work on std::array and std::string. 
+template +static bool ConstantTimeEqualsPrefix(const T1& a, const T2& b, size_t prefix_size) { + if (a.size() < prefix_size || b.size() < prefix_size) return false; // not constant time, but we generally don't care. + const uint8_t* aptr = reinterpret_cast(a.data()); + const uint8_t* bptr = reinterpret_cast(b.data()); + uint8_t out = 0; + while (prefix_size--) { + out |= (*aptr++) ^ (*bptr++); + } + return out == 0; +} + +// Templatized to work on std::array and std::string. +template +static bool ConstantTimeEquals(const T1& a, const T2& b) { + if (a.size() != b.size()) return false; // not constant time, but we generally don't care. + return ConstantTimeEqualsPrefix(a, b, a.size()); +} + +} // namespace svr2::util + +#endif // __SVR2_UTIL_CONSTANT_H__ diff --git a/enclave/util/cpu.h b/enclave/util/cpu.h new file mode 100644 index 0000000..37dca01 --- /dev/null +++ b/enclave/util/cpu.h @@ -0,0 +1,21 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_CPU_H__ +#define __SVR2_UTIL_CPU_H__ + +#include + +namespace svr2::util { + +// `rdtsc` gets the current CPU ticks from the current CPU. 
+uint64_t asm_rdtsc(); +inline uint64_t asm_rdtsc() { + uint64_t lo, hi; + asm volatile( "rdtsc" : "=a" (lo), "=d" (hi) ); + return lo | ( hi << 32 ); +} + +} // namespace svr2::util + +#endif // __SVR2_UTIL_CPU_H__ diff --git a/enclave/util/endian.h b/enclave/util/endian.h new file mode 100644 index 0000000..c0cd4f0 --- /dev/null +++ b/enclave/util/endian.h @@ -0,0 +1,62 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_ENDIAN_H__ +#define __SVR2_UTIL_ENDIAN_H__ + +namespace svr2::util { + +inline uint64_t BigEndian64FromBytes(const uint8_t in[8]) { + return ((uint64_t)in[0]) << (8*7) | + ((uint64_t)in[1]) << (8*6) | + ((uint64_t)in[2]) << (8*5) | + ((uint64_t)in[3]) << (8*4) | + ((uint64_t)in[4]) << (8*3) | + ((uint64_t)in[5]) << (8*2) | + ((uint64_t)in[6]) << (8*1) | + ((uint64_t)in[7]) << (8*0); +} + +inline uint32_t BigEndian32FromBytes(const uint8_t in[4]) { + return ((uint32_t)in[0]) << (8*3) | + ((uint32_t)in[1]) << (8*2) | + ((uint32_t)in[2]) << (8*1) | + ((uint32_t)in[3]) << (8*0); +} + +inline uint64_t BigEndian64FromBytes(const char* in) { + return BigEndian64FromBytes(reinterpret_cast(in)); +} + +inline void BigEndian64Bytes(uint64_t v, uint8_t out[8]) { + out[0] = v >> (8*7); + out[1] = v >> (8*6); + out[2] = v >> (8*5); + out[3] = v >> (8*4); + out[4] = v >> (8*3); + out[5] = v >> (8*2); + out[6] = v >> (8*1); + out[7] = v >> (8*0); +} + +inline void BigEndian32Bytes(uint32_t v, uint8_t out[4]) { + out[0] = v >> (8*3); + out[1] = v >> (8*2); + out[2] = v >> (8*1); + out[3] = v >> (8*0); +} + +inline void LittleEndian64Bytes(uint64_t v, uint8_t out[8]) { + out[0] = v >> (8*0); + out[1] = v >> (8*1); + out[2] = v >> (8*2); + out[3] = v >> (8*3); + out[4] = v >> (8*4); + out[5] = v >> (8*5); + out[6] = v >> (8*6); + out[7] = v >> (8*7); +} + +} // namespace svr2::util + +#endif // __SVR2_UTIL_ENDIAN_H__ diff --git a/enclave/util/hex.cc b/enclave/util/hex.cc new file mode 100644 index 
0000000..6c8684d --- /dev/null +++ b/enclave/util/hex.cc @@ -0,0 +1,30 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "util/hex.h" + +namespace svr2::util { + +std::string BytesToHex(const uint8_t* in, size_t size) { + static const char* nibbles = "0123456789abcdef"; + std::string out(size * 2, ' '); + for (size_t i = 0; i < size; i++) { + out[i*2+0] = nibbles[(in[i] & 0xf0) >> 4]; + out[i*2+1] = nibbles[(in[i] & 0x0f) >> 0]; + } + return out; +} + +std::string HexToBytes(std::string hex) { + std::string bytes; + + for (unsigned int i = 0; i < hex.length(); i += 2) { + std::string byteString = hex.substr(i, 2); + char byte = (char) strtol(byteString.c_str(), NULL, 16); + bytes.push_back(byte); + } + + return bytes; +} + +} // namespace svr2::util diff --git a/enclave/util/hex.h b/enclave/util/hex.h new file mode 100644 index 0000000..87309f9 --- /dev/null +++ b/enclave/util/hex.h @@ -0,0 +1,29 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_HEX_H__ +#define __SVR2_UTIL_HEX_H__ + +#include +#include + +namespace svr2::util { + +std::string BytesToHex(const uint8_t* in, size_t size); +std::string HexToBytes(std::string hex); + +// Turns the `s`-byte prefix of `in` into `s*2` hex characters and returns it as a string. +template +std::string PrefixToHex(const T& in, size_t s) { + return BytesToHex(reinterpret_cast(in.data()), std::min(s, in.size())); +} + +// Turns the bytes of `in` into `in.size()*2` hex characters and returns it as a string. 
+template +std::string ToHex(const T& in) { + return PrefixToHex(in, in.size()); +} + +} // namespace svr2::util + +#endif // __SVR2_UTIL_HEX_H__ diff --git a/enclave/util/log.cc b/enclave/util/log.cc new file mode 100644 index 0000000..ec581e5 --- /dev/null +++ b/enclave/util/log.cc @@ -0,0 +1,33 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "util/log.h" +#include "env/env.h" +#include "util/macros.h" + +namespace svr2::util { + +::svr2::enclaveconfig::EnclaveLogLevel log_level_to_write = +#ifdef IS_TEST + enclaveconfig::LOG_LEVEL_MAX; +#else + enclaveconfig::LOG_LEVEL_INFO; +#endif + +Log::Log(::svr2::enclaveconfig::EnclaveLogLevel lvl) : lvl_(lvl) {} + +Log::~Log() { + env::environment->Log(lvl_, ss_.str()); + if (lvl_ == enclaveconfig::LOG_LEVEL_FATAL) { CHECK(false); } +} + +void SetLogLevel(::svr2::enclaveconfig::EnclaveLogLevel level) { + log_level_to_write = level; +} + +} // namespace svr2::util + +std::ostream& operator<<(std::ostream& os, ::svr2::error::Error err) { + os << ::svr2::error::Error_Name(err); + return os; +} diff --git a/enclave/util/log.h b/enclave/util/log.h new file mode 100644 index 0000000..6d05305 --- /dev/null +++ b/enclave/util/log.h @@ -0,0 +1,41 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_LOG_H__ +#define __SVR2_UTIL_LOG_H__ + +#include +#include +#include +#include "proto/error.pb.h" +#include "proto/msgs.pb.h" + +std::ostream& operator<<(std::ostream& os, ::svr2::error::Error err); + +namespace svr2::util { + +class Log { + public: + Log(::svr2::enclaveconfig::EnclaveLogLevel lvl); + ~Log(); + + template + std::ostream& operator<<(T x) { + ss_ << x; + return ss_; + } + + private: + ::svr2::enclaveconfig::EnclaveLogLevel lvl_; + std::stringstream ss_; +}; + +extern ::svr2::enclaveconfig::EnclaveLogLevel log_level_to_write; + +void SetLogLevel(::svr2::enclaveconfig::EnclaveLogLevel level); + +} // namespace 
svr2::util + +#define LOG(x) if (::svr2::enclaveconfig::LOG_LEVEL_##x <= ::svr2::util::log_level_to_write) ::svr2::util::Log(::svr2::enclaveconfig::LOG_LEVEL_##x) << #x << "\t" << __FILE__ << ":" << __LINE__ << "(" << __FUNCTION__ << ") T=" << std::this_thread::get_id() << " - " + +#endif // __SVR2_UTIL_LOG_H__ diff --git a/enclave/util/macros.h b/enclave/util/macros.h new file mode 100644 index 0000000..6d598af --- /dev/null +++ b/enclave/util/macros.h @@ -0,0 +1,34 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_MACROS_H__ +#define __SVR2_UTIL_MACROS_H__ + +#include +#include + +#define CHECK(x) do { \ + if (!(x)) { \ + fprintf(stderr, "CHECK FAIL @ %s:%d in %s: %s\n", __FILE__, __LINE__, __FUNCTION__, #x); \ + abort(); \ + } \ +} while (0) + +#define RETURN_IF_ERROR(x) do { \ + ::svr2::error::Error _err_ = (x); \ + if (_err_ != ::svr2::error::OK) return _err_; \ +} while (0) + +#define DELETE_COPY_AND_ASSIGN(x) \ + x(x& other) = delete; \ + void operator=(const x &) = delete +#define DELETE_ASSIGN(x) \ + void operator=(const x &) = delete + +#ifdef IS_TEST +#define public_for_test public +#else +#define public_for_test private +#endif + +#endif // __SVR2_UTIL_MACROS_H__ diff --git a/enclave/util/mutex.h b/enclave/util/mutex.h new file mode 100644 index 0000000..b7152d9 --- /dev/null +++ b/enclave/util/mutex.h @@ -0,0 +1,52 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_MUTEX_H__ +#define __SVR2_UTIL_MUTEX_H__ + +#include +#include "util/threadsafetyannotations.h" +#include "util/macros.h" + +namespace svr2::util { + +// These classes are simple wrappers around equivalent std::xxx classes, +// except they've been augmented with Clang thread safety annotations +// for static analysis of locking. 
+ +class CAPABILITY("mutex") mutex { + public: + DELETE_COPY_AND_ASSIGN(mutex); + mutex() {} + inline void lock() ACQUIRE() { mu_.lock(); } + inline void unlock() RELEASE() { mu_.unlock(); } + inline bool try_lock() TRY_ACQUIRE(true) { return mu_.try_lock(); } + + // For negative thread safety analysis capabilities only. + const mutex& operator!() const { + CHECK(nullptr == "this function should be used only for thread annotations"); + return *this; + } + + private: + std::mutex mu_; +}; + +template +class SCOPED_CAPABILITY unique_lock { + public: + DELETE_COPY_AND_ASSIGN(unique_lock); + unique_lock(T& mu) ACQUIRE(mu) : mu_(mu), locked_(true) { mu_.lock(); } + ~unique_lock() RELEASE() { if (locked_) mu_.unlock(); } + unique_lock(T& mu, std::defer_lock_t d) EXCLUDES(mu) : mu_(mu), locked_(false) { } + inline void lock() ACQUIRE() { mu_.lock(); locked_ = true; } + inline void unlock() RELEASE() { mu_.unlock(); locked_ = false; } + + private: + T& mu_; + bool locked_; +}; + +} // namespace svr2::util + +#endif // __SVR2_UTIL_MUTEX_H__ diff --git a/enclave/util/tests/constant.cc b/enclave/util/tests/constant.cc new file mode 100644 index 0000000..70694de --- /dev/null +++ b/enclave/util/tests/constant.cc @@ -0,0 +1,28 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP util +#include +#include "util/constant.h" +#include +#include + +namespace svr2::util { + +class ConstantTest : public ::testing::Test {}; + +TEST_F(ConstantTest, Equality) { + std::string a("abc"); + std::array b{'a', 'b', 'c'}; + EXPECT_TRUE(ConstantTimeEquals(a, a)); + EXPECT_TRUE(ConstantTimeEquals(a, b)); + EXPECT_TRUE(ConstantTimeEquals(b, a)); + std::string c("aBc"); + EXPECT_FALSE(ConstantTimeEquals(a, c)); + EXPECT_FALSE(ConstantTimeEquals(c, a)); + EXPECT_FALSE(ConstantTimeEquals(b, c)); + EXPECT_FALSE(ConstantTimeEquals(c, b)); +} + +} // namespace svr2::util diff --git a/enclave/util/tests/endian.cc 
b/enclave/util/tests/endian.cc new file mode 100644 index 0000000..c68c5a9 --- /dev/null +++ b/enclave/util/tests/endian.cc @@ -0,0 +1,31 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP util +#include +#include "util/endian.h" +#include +#include + +namespace svr2::util { + +class EndianTest : public ::testing::Test {}; + +TEST_F(EndianTest, BigEndian64RoundTrip) { + uint8_t buf[8] = {0}; + BigEndian64Bytes(0xfedcba9876543210ULL, buf); + ASSERT_EQ(BigEndian64FromBytes(buf), 0xfedcba9876543210ULL); + uint8_t expected[8] = {0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}; + ASSERT_EQ(0, memcmp(buf, expected, 8)); +} + +TEST_F(EndianTest, BigEndian32RoundTrip) { + uint8_t buf[4] = {0}; + BigEndian32Bytes(0xfedc4321UL, buf); + ASSERT_EQ(BigEndian32FromBytes(buf), 0xfedc4321UL); + uint8_t expected[8] = {0xfe, 0xdc, 0x43, 0x21}; + ASSERT_EQ(0, memcmp(buf, expected, 4)); +} + +} // namespace svr2::util diff --git a/enclave/util/tests/hex.cc b/enclave/util/tests/hex.cc new file mode 100644 index 0000000..ffd696e --- /dev/null +++ b/enclave/util/tests/hex.cc @@ -0,0 +1,29 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +//TESTDEP gtest +//TESTDEP util +#include +#include "util/hex.h" +#include + +namespace svr2::util { + +class HexTest : public ::testing::Test {}; + +TEST_F(HexTest, ToHex) { + std::string a("\x01\x02\x0a"); + std::array b{4, 0x3b, 0xff}; + EXPECT_EQ("01020a", ToHex(a)); + EXPECT_EQ("043bff", ToHex(b)); +} + +TEST_F(HexTest, PrefixToHex) { + std::array b{4, 0x3b, 0xff}; + EXPECT_EQ("", PrefixToHex(b, 0)); + EXPECT_EQ("043b", PrefixToHex(b, 2)); + EXPECT_EQ("043bff", PrefixToHex(b, 3)); + EXPECT_EQ("043bff", PrefixToHex(b, 4)); +} + +} // namespace svr2::util diff --git a/enclave/util/threadsafetyannotations.h b/enclave/util/threadsafetyannotations.h new file mode 100644 index 0000000..bc19bc8 --- /dev/null +++ b/enclave/util/threadsafetyannotations.h 
@@ -0,0 +1,74 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_THREADSAFETYANNOTATIONS_H__ +#define __SVR2_UTIL_THREADSAFETYANNOTATIONS_H__ + +// Taken from https://releases.llvm.org/11.0.0/tools/clang/docs/ThreadSafetyAnalysis.html + +// Enable thread safety attributes only with clang. +// The attributes can be safely erased when compiling with other compilers. +#if defined(__clang__) && (!defined(SWIG)) +#define THREAD_ANNOTATION_ATTRIBUTE__(x) __attribute__((x)) +#else +#define THREAD_ANNOTATION_ATTRIBUTE__(x) // no-op +#endif + +#define CAPABILITY(x) \ + THREAD_ANNOTATION_ATTRIBUTE__(capability(x)) + +#define SCOPED_CAPABILITY \ + THREAD_ANNOTATION_ATTRIBUTE__(scoped_lockable) + +#define GUARDED_BY(x) \ + THREAD_ANNOTATION_ATTRIBUTE__(guarded_by(x)) + +#define PT_GUARDED_BY(x) \ + THREAD_ANNOTATION_ATTRIBUTE__(pt_guarded_by(x)) + +#define ACQUIRED_BEFORE(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquired_before(__VA_ARGS__)) + +#define ACQUIRED_AFTER(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquired_after(__VA_ARGS__)) + +#define REQUIRES(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(requires_capability(__VA_ARGS__)) + +#define REQUIRES_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(requires_shared_capability(__VA_ARGS__)) + +#define ACQUIRE(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquire_capability(__VA_ARGS__)) + +#define ACQUIRE_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquire_shared_capability(__VA_ARGS__)) + +#define RELEASE(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(release_capability(__VA_ARGS__)) + +#define RELEASE_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(release_shared_capability(__VA_ARGS__)) + +#define TRY_ACQUIRE(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(try_acquire_capability(__VA_ARGS__)) + +#define TRY_ACQUIRE_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(try_acquire_shared_capability(__VA_ARGS__)) + +#define EXCLUDES(...) 
\ + THREAD_ANNOTATION_ATTRIBUTE__(locks_excluded(__VA_ARGS__)) + +#define ASSERT_CAPABILITY(x) \ + THREAD_ANNOTATION_ATTRIBUTE__(assert_capability(x)) + +#define ASSERT_SHARED_CAPABILITY(x) \ + THREAD_ANNOTATION_ATTRIBUTE__(assert_shared_capability(x)) + +#define RETURN_CAPABILITY(x) \ + THREAD_ANNOTATION_ATTRIBUTE__(lock_returned(x)) + +#define NO_THREAD_SAFETY_ANALYSIS \ + THREAD_ANNOTATION_ATTRIBUTE__(no_thread_safety_analysis) + +#endif // __SVR2_UTIL_THREADSAFETYANNOTATIONS_H__ diff --git a/enclave/util/ticks.cc b/enclave/util/ticks.cc new file mode 100644 index 0000000..4dc8772 --- /dev/null +++ b/enclave/util/ticks.cc @@ -0,0 +1,8 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#include "util/ticks.h" + +namespace svr2::util { +const Ticks InvalidTicks = INT64_MAX; +} // namespace svr2::util diff --git a/enclave/util/ticks.h b/enclave/util/ticks.h new file mode 100644 index 0000000..65bd114 --- /dev/null +++ b/enclave/util/ticks.h @@ -0,0 +1,18 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +#ifndef __SVR2_UTIL_TICKS_H__ +#define __SVR2_UTIL_TICKS_H__ + +#include +#include + +namespace svr2::util { + +typedef int64_t Ticks; +extern const Ticks InvalidTicks; +typedef time_t UnixSecs; + +} // namespace svr2::util + +#endif // __SVR2_UTIL_TICKS_H__ diff --git a/host/.gitignore b/host/.gitignore new file mode 100644 index 0000000..cf2d3d5 --- /dev/null +++ b/host/.gitignore @@ -0,0 +1,5 @@ +*.pb.go +main +enclave/c +enclave/enclave.test +cmd/control/control diff --git a/host/.tool-versions b/host/.tool-versions new file mode 100644 index 0000000..fe2801d --- /dev/null +++ b/host/.tool-versions @@ -0,0 +1 @@ +golang 1.19.2 diff --git a/host/Makefile b/host/Makefile new file mode 100644 index 0000000..9d46b45 --- /dev/null +++ b/host/Makefile @@ -0,0 +1,65 @@ +OE_INCDIR = $(shell pkg-config oeenclave-clang++ --variable=includedir) +CC=clang-11 +GO_TEST_FLAGS ?= + +all: build test 
+.PHONY: all clean protos build test ../enclave/build/enclave.test validate + +build: generated + go build main.go + +../enclave/build/enclave.test: + $(MAKE) -C ../enclave build/enclave.test + +# -count=1 forces this test to run un-cached, since enclave.test may have changed +# even though Go code has not, and tests may depend on it. +test: build | ../enclave/build/enclave.test + go test $(GO_TEST_FLAGS) -count=1 ./... + +EDGER8R_FILES=enclave/c/svr2_u.c enclave/c/svr2_u.h enclave/c/svr2_args.h +# This $(firstword) trick allows for grouped targets. +$(filter-out $(firstword $(EDGER8R_FILES)),$(EDGER8R_FILES)): $(firstword $(EDGER8R_FILES)) +$(firstword $(EDGER8R_FILES)): ../shared/svr2.edl + mkdir -p enclave/c + oeedger8r $< --untrusted \ + --untrusted-dir enclave/c \ + --search-path $(OE_INCDIR) \ + --search-path $(OE_INCDIR)/openenclave/edl/sgx +enclave/c/libsvr2.a: $(EDGER8R_FILES) + $(CC) -c -o enclave/c/svr2.o $(shell pkg-config oehost-clang --cflags) enclave/c/svr2_u.c + ar rcs $@ enclave/c/svr2.o + +PROTO_FILES = \ + $(patsubst ../shared/proto/%.proto,proto/%.pb.go,$(wildcard ../shared/proto/*.proto)) \ + $(patsubst proto/%.proto,proto/%.pb.go,$(wildcard proto/*.proto)) \ +## PROTO_FILES +# This $(firstword) trick allows for grouped targets. +$(filter-out $(firstword $(PROTO_FILES)),$(PROTO_FILES)): $(firstword $(PROTO_FILES)) +$(firstword $(PROTO_FILES)): ../shared/proto/*.proto proto/*.proto + protoc --go_out=. --go_opt=module=github.com/signalapp/svr2 --proto_path=../shared/proto ../shared/proto/*.proto + protoc --go_out=. --go_opt=module=github.com/signalapp/svr2 --proto_path=../shared/proto --proto_path=proto proto/*.proto +protos: $(PROTO_FILES) +generated: protos enclave/c/libsvr2.a + +validate: generated + go vet ./... 
+ [ -z "$$(go fmt ./...)" ] + +clean: + rm -vfr enclave/c + rm -vfr main + rm -vf proto/*.pb.go + rm -vf .test_enclave + rm -vf enclave/enclave.test + +enclave/enclave.test: generated + (cd enclave && go test -c) + +enclave_test_gdb: enclave/enclave.test ../enclave/build/enclave.test + (cd enclave && /opt/openenclave/bin/oegdb enclave.test) + +miniredis: + (cd miniredis && go build miniredis.go) + +control: generated + (cd cmd/control && go build) diff --git a/host/README.md b/host/README.md new file mode 100644 index 0000000..ed57c53 --- /dev/null +++ b/host/README.md @@ -0,0 +1,18 @@ +# SVR2 Host Code (the Go attempt) + +This codebase provides a host-side binary for running and interacting with +an enclave, while also interacting with the outside world (external services, +clients, etc.), and acts as a bridge between these two worlds. It follows +typical Go paradigms. + +## Go Version + +Currently, this code is targeted at Go 1.19, so yay, generics are a thing. + +## Testing + +There should be some. + +## Formatting + +One `gofmt` to rule them all. diff --git a/host/auth/auth.go b/host/auth/auth.go new file mode 100644 index 0000000..b11b060 --- /dev/null +++ b/host/auth/auth.go @@ -0,0 +1,101 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Package auth provides for the ability to authenticate clients using +// basic auth credentials they get from Signal chat servers' +// ExternalServiceCredentialsGenerator. +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + + "github.com/signalapp/svr2/util" +) + +const ( + authenticationTokenMaxAgeSeconds = 30 * 86400 +) + +// Auth allows us to check a username and password, or generate a password for a user. +type Auth interface { + // Check returns nil if this user/pass combination is legitimate. + // Otherwise, it returns an error describing the reason it's invalid. 
+ Check(user, pass string) error + // PassFor returns a valid password for a given user at the current time. + PassFor(user string) string +} + +// New returns a new production Auth based on the given secret, using the real clock and a fixed 30-day expiration. +func New(secret []byte) Auth { + return &auth{secret: secret, clock: util.RealClock, expiration: time.Second * authenticationTokenMaxAgeSeconds} +} + +type alwaysAllow struct{} + +func (a alwaysAllow) Check(user, pass string) error { + return nil +} +func (a alwaysAllow) PassFor(user string) string { + return "wheee" +} + +// AlwaysAllow provides an Auth that will always allow clients to connect. +var AlwaysAllow = Auth(alwaysAllow{}) + +type auth struct { + secret []byte + clock util.Clock + expiration time.Duration +} + +func (a *auth) Check(user, pass string) error { + ts, sig, err := a.parsePass(pass) + if err != nil { + return err + } + return a.valid(user, ts, sig) +} + +func (a *auth) parsePass(pass string) (ts time.Time, sig []byte, _ error) { + i := strings.Index(pass, ":") + if i < 0 { + return time.Time{}, nil, fmt.Errorf("no separator") + } + unixSecs, err := strconv.ParseInt(pass[:i], 10, 64) + if err != nil { + return time.Time{}, nil, fmt.Errorf("parsing timestamp: %v", err) + } + ts = time.Unix(unixSecs, 0) + sig, err = hex.DecodeString(pass[i+1:]) + return ts, sig, err +} + +func (a *auth) valid(user string, ts time.Time, sig []byte) error { + diff := a.clock.Now().Sub(ts) + if diff > a.expiration || diff < -a.expiration { + return fmt.Errorf("expired") + } + mac := hmac.New(sha256.New, a.secret) + fmt.Fprintf(mac, "%s:%d", user, ts.Unix()) + var sum [sha256.Size]byte + mac.Sum(sum[:0]) + if subtle.ConstantTimeCompare(sum[:10], sig) != 1 { + return fmt.Errorf("mac failure") + } + return nil +} + +func (a *auth) PassFor(user string) string { + ts := a.clock.Now() + mac := hmac.New(sha256.New, a.secret) + fmt.Fprintf(mac, "%s:%d", user, ts.Unix()) + key := mac.Sum(nil)[:10] + return fmt.Sprintf("%d:%x", ts.Unix(), key) +} diff 
--git a/host/auth/auth_test.go b/host/auth/auth_test.go new file mode 100644 index 0000000..3d2e5aa --- /dev/null +++ b/host/auth/auth_test.go @@ -0,0 +1,45 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package auth + +import ( + "fmt" + "testing" + "time" + + "github.com/signalapp/svr2/util" +) + +func TestAuthWorks(t *testing.T) { + a := &auth{ + secret: []byte{1, 2, 3, 4, 5}, + clock: util.TestAt(time.Unix(10000, 0)), + expiration: 3600 * time.Second, + } + for _, test := range []struct{ user, pass string }{ + {user: "12345", pass: "10000:8b2df41718f48f312c6d"}, + {user: "12345", pass: "13600:fb1e57e272683fb785b1"}, + {user: "12345", pass: "6400:614c7129a946e79c83ed"}, + {user: "123456", pass: "10000:9a08d531879caa2a81f0"}, + {user: "wizzle", pass: a.PassFor("wizzle")}, + } { + t.Logf("%+v", test) + if err := a.Check(test.user, test.pass); err != nil { + t.Errorf("expected check success, got error: %v", err) + } + } + + validPass := []byte{0x8b, 0x2d, 0xf4, 0x17, 0x18, 0xf4, 0x8f, 0x31, 0x2c, 0x6d} + for i := 0; i < len(validPass); i++ { + for j := 0; j < 8; j++ { + b := make([]byte, len(validPass)) + copy(b, validPass) + b[i] ^= 1 << j + badPass := fmt.Sprintf("10000:%x", b) + if err := a.Check("12345", badPass); err == nil { + t.Errorf("bitflipped pass, want error got success: valid=%x ours=%q", validPass, badPass) + } + } + } +} diff --git a/host/cmd/control/main.go b/host/cmd/control/main.go new file mode 100644 index 0000000..98b54a0 --- /dev/null +++ b/host/cmd/control/main.go @@ -0,0 +1,73 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/signalapp/svr2/web/client" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + pb "github.com/signalapp/svr2/proto" +) + +var ( + addr = flag.String("addr", "localhost:8081", "Address (hostname:port) where control server is listening") 
+ binary = flag.Bool("bin", false, "If true, assume a binary formatted proto file. Otherwise, protojson") +) + +func main() { + flag.Usage = func() { + fmt.Fprint(flag.CommandLine.Output(), "Issue a control command for a HostToEnclaveRequest proto \n") + fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [flags] proto_filename \n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + if flag.NArg() == 0 { + flag.Usage() + os.Exit(1) + } + + bs, err := requestBody(flag.Args()[0]) + if err != nil { + fmt.Fprint(os.Stderr, err) + os.Exit(1) + } + + cc := client.ControlClient{Addr: *addr} + resp, err := cc.DoJSON(bs) + if err != nil { + fmt.Fprint(os.Stderr, err) + os.Exit(1) + } + fmt.Fprintln(os.Stderr, "successfully executed control request") + fmt.Println(protojson.Format(resp)) +} + +func requestBody(filename string) ([]byte, error) { + bs, err := os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("failed to read file : %v", err) + } + + request := pb.HostToEnclaveRequest{} + if *binary { + err = proto.Unmarshal(bs, &request) + } else { + err = protojson.Unmarshal(bs, &request) + } + if err != nil { + return nil, fmt.Errorf("failed to parse proto : %v", err) + } + if !*binary { + return bs, nil + } + if bs, err = protojson.Marshal(&request); err != nil { + return nil, fmt.Errorf("failed to marshal proto : %v", err) + } + return bs, nil +} diff --git a/host/cmd/svr2client/main.go b/host/cmd/svr2client/main.go new file mode 100644 index 0000000..fba985d --- /dev/null +++ b/host/cmd/svr2client/main.go @@ -0,0 +1,261 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package main + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "encoding/base64" + "encoding/hex" + "flag" + "fmt" + "log" + "net/http" + "net/url" + "os" + "sync" + "sync/atomic" + + "github.com/gorilla/websocket" + "github.com/signalapp/svr2/auth" + "github.com/signalapp/svr2/web/client" + + pb "github.com/signalapp/svr2/proto" +) + 
+var ( + backupCmd = flag.NewFlagSet("backup", flag.ExitOnError) + exposeCmd = flag.NewFlagSet("expose", flag.ExitOnError) + restoreCmd = flag.NewFlagSet("restore", flag.ExitOnError) + deleteCmd = flag.NewFlagSet("delete", flag.ExitOnError) + loadtestCmd = flag.NewFlagSet("loadtest", flag.ExitOnError) + + user = toUser("test123") + host, enclaveID, authKey string + useTLS bool +) + +var subcommands = map[string]*flag.FlagSet{ + backupCmd.Name(): backupCmd, + exposeCmd.Name(): exposeCmd, + restoreCmd.Name(): restoreCmd, + deleteCmd.Name(): deleteCmd, + loadtestCmd.Name(): loadtestCmd, +} + +func main() { + for _, fs := range subcommands { + fs.StringVar(&host, "host", "svr2.staging.signal.org", "endpoint to connect to") + fs.StringVar(&enclaveID, "enclaveId", "7d44d147f38d102c2874ffcd92302398ac2b38592633bb20c75dce9c171fe877", "mrenclave to use") + fs.StringVar(&authKey, "authKey", "", "base64 encoded shared svr auth key") + fs.Func("user", "basic auth username. If it's not a 32 character hex string it will be hashed", func(s string) error { + user = toUser(s) + return nil + }) + fs.BoolVar(&useTLS, "useTLS", true, "whether to use TLS") + } + + switch os.Args[1] { + case backupCmd.Name(): + backupCmd.Parse(os.Args[2:]) + if err := runBackup(user); err != nil { + fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + } + case exposeCmd.Name(): + exposeCmd.Parse(os.Args[2:]) + if err := runExpose(user); err != nil { + fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + } + case restoreCmd.Name(): + var pin string + restoreCmd.StringVar(&pin, "pin", "", "pin") + restoreCmd.Parse(os.Args[2:]) + if err := runRestore(pin); err != nil { + fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + } + case deleteCmd.Name(): + deleteCmd.Parse(os.Args[2:]) + if err := runDelete(); err != nil { + fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + } + case loadtestCmd.Name(): + parallel := loadtestCmd.Int("parallel", 1, "amount of parallelization") + count := loadtestCmd.Int("count", 1, "total 
count to run") + loadtestCmd.Parse(os.Args[2:]) + if err := runLoadTest(*parallel, *count); err != nil { + fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + } + } +} + +func toUser(usernameRaw string) string { + bs, err := hex.DecodeString(usernameRaw) + if err == nil && len(bs) == 16 { + return usernameRaw + } + h := sha256.Sum256([]byte(usernameRaw)) + return hex.EncodeToString(h[:16]) +} + +func newClient(username string) (*client.SVR2Client, error) { + u := url.URL{Scheme: "wss", Host: host, Path: fmt.Sprintf("v1/%s", enclaveID)} + if !useTLS { + u.Scheme = "ws" + } + log.Printf("%v as %v", u, username) + dialer := *websocket.DefaultDialer + if useTLS { + dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + authBytes, err := base64.StdEncoding.DecodeString(authKey) + if err != nil { + return nil, err + } + c, resp, err := dialer.Dial(u.String(), http.Header{ + "Authorization": []string{"Basic " + base64.URLEncoding.EncodeToString([]byte(username+":"+auth.New(authBytes).PassFor(username)))}, + }) + if err != nil { + return nil, fmt.Errorf("dial %v", err) + } else if resp.StatusCode > 299 { + return nil, fmt.Errorf("code %v", resp.Status) + } + + return client.NewClient(c) +} + +func runRestore(hexPin string) error { + c, err := newClient(user) + if err != nil { + return err + } + pin, err := hex.DecodeString(hexPin) + if err != nil { + return err + } + + r, err := c.Send(&pb.Request{Inner: &pb.Request_Restore{ + Restore: &pb.RestoreRequest{ + Pin: pin, + }, + }}) + if err != nil { + return err + } + log.Print(r) + return nil + +} + +func runLoadTest(parallel, count int) error { + countU32 := int32(count) + var wg sync.WaitGroup + for i := 0; i < parallel; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + u := atomic.AddInt32(&countU32, -1) + if u < 0 { + return + } + user := toUser(fmt.Sprintf("%s_%d", user, u)) + if err := runBackup(user); err != nil { + log.Printf("user %d failed backup: %v", u, err) + } + if err := runExpose(user); 
err != nil { + log.Printf("user %d failed expose: %v", u, err) + } + } + }() + } + wg.Wait() + return nil +} + +func bytesForUser(username string) []byte { + h := sha256.Sum256([]byte(username)) + return h[:] +} + +func runBackup(username string) error { + c, err := newClient(username) + if err != nil { + return err + } + + b := bytesForUser(username) + r, err := c.Send(&pb.Request{Inner: &pb.Request_Backup{ + Backup: &pb.BackupRequest{ + Data: b, + Pin: b, + MaxTries: 5, + }, + }}) + if err != nil { + return err + } + br, ok := r.Inner.(*pb.Response_Backup) + if !ok { + return fmt.Errorf("unexpected response : %v", r) + } + if br.Backup.Status != pb.BackupResponse_OK { + return fmt.Errorf("backup request not successful: %v", br.Backup.Status) + } + log.Printf("successful: data=pin=%x", b) + return nil +} + +func runExpose(username string) error { + c, err := newClient(username) + if err != nil { + return err + } + + b := bytesForUser(username) + r, err := c.Send(&pb.Request{Inner: &pb.Request_Expose{ + Expose: &pb.ExposeRequest{ + Data: b, + }, + }}) + if err != nil { + return err + } + br, ok := r.Inner.(*pb.Response_Expose) + if !ok { + return fmt.Errorf("unexpected response : %v", r) + } + if br.Expose.Status != pb.ExposeResponse_OK { + return fmt.Errorf("backup request not successful: %v", br.Expose.Status) + } + log.Printf("successful") + return nil +} + +func runDelete() error { + c, err := newClient(user) + if err != nil { + return err + } + + r, err := c.Send(&pb.Request{Inner: &pb.Request_Delete{Delete: &pb.DeleteRequest{}}}) + if err != nil { + return err + } + log.Print(r) + return nil +} + +func randBytes(count int) []byte { + bs := make([]byte, count) + if _, err := rand.Read(bs); err != nil { + log.Fatalf("rand: %v", err) + } + return bs +} diff --git a/host/config/config.go b/host/config/config.go new file mode 100644 index 0000000..165ba63 --- /dev/null +++ b/host/config/config.go @@ -0,0 +1,138 @@ +// Copyright 2023 Signal Messenger, LLC +// 
SPDX-License-Identifier: AGPL-3.0-only + +package config + +import ( + "fmt" + "os" + "runtime" + "strings" + "time" + + "github.com/signalapp/svr2/util" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/yaml.v2" +) + +type Config struct { + // See zap.Config + Log *zap.Config `yaml:"log"` + // Address for the peer server to listen on (ex 10.0.0.1:1234) + PeerAddr string `yaml:"peerAddr"` + // Address for http client server to listen on + ClientListenAddr string `yaml:"clientListenAddr"` + // Address for http control server to listen on + ControlListenAddr string `yaml:"controlListenAddr"` + // Configuration for redis cluster + Redis RedisConfig `yaml:"redis"` + // HTTP endpoint rate limits + Limit RateLimitConfig `yaml:"limit"` + // Host specific Raft configuration + Raft RaftHostConfig `yaml:"raft"` + // Peer protocol configuration + Peer PeerConfig `yaml:"peer"` + // The MRENCLAVE this host serves + EnclaveID string `yaml:"enclaveId"` + // Address to reach a datadog compatible statsd + DatadogAgentHost string `yaml:"datadogAgentHost"` + // TTL of initial Redis peerdb entry. + InitialRedisPeerDBTTL time.Duration `yaml:"initialRedisPeerDBTTL"` + // TTL of recurring Redis peerdb entry. + RecurringRedisPeerDBTTL time.Duration `yaml:"recurringRedisPeerDBTTL"` + // Configuration for the client websocket handler + Request RequestConfig `yaml:"request"` +} + +// validate returns a list of validation errors, or empty if there are no errors. +type validator interface{ validate() []string } + +func (c *Config) validate() error { + validators := []validator{&c.Raft, &c.Redis, &c.Limit, &c.Request} + var errs []string + for _, validator := range validators { + errs = append(errs, validator.validate()...) 
+ } + if len(errs) != 0 { + return fmt.Errorf("invalid config: %v", strings.Join(errs, ",")) + } + return nil +} + +// Read parses the yaml file at the provided path into a Config +func Read(path string) (*Config, error) { + bs, err := os.ReadFile(path) + if err != nil { + return nil, err + } + withenv := []byte(os.ExpandEnv(string(bs))) + c, err := unmarshal(withenv) + if err != nil { + return nil, err + } + if err := c.validate(); err != nil { + return nil, err + } + return c, nil +} + +func unmarshal(bs []byte) (*Config, error) { + cfg := Default() + if err := yaml.Unmarshal(bs, cfg); err != nil { + return nil, err + } + return cfg, nil +} + +// Default provides reasonable default parameters that may be overridden by a config file +func Default() *Config { + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + config := zap.Config{ + Level: zap.NewAtomicLevelAt(zap.DebugLevel), + Development: true, + Encoding: "console", + EncoderConfig: encoderConfig, + OutputPaths: []string{"stderr"}, + ErrorOutputPaths: []string{"stderr"}, + DisableStacktrace: true, + } + + return &Config{ + Log: &config, + PeerAddr: "localhost:9000", + ClientListenAddr: "localhost:8080", + ControlListenAddr: "localhost:8081", + Raft: RaftHostConfig{ + RefreshStatusDuration: time.Minute, + TickDuration: time.Second, + MetricPollDuration: time.Second * 10, + RefreshAttestationDuration: time.Minute * 10, + EnclaveConcurrency: util.Min(runtime.NumCPU(), 64), + }, + Redis: RedisConfig{ + Name: "test", + MinSleepDuration: time.Second, + MaxSleepDuration: time.Second * 30, + }, + Limit: RateLimitConfig{ + BucketSize: 10, + LeakRateScalar: 10, + LeakRateDuration: time.Minute, + }, + Peer: PeerConfig{ + MinSleepDuration: time.Millisecond * 10, + MaxSleepDuration: time.Minute, + AbandonDuration: time.Minute * 10, + BufferSize: 10_000, + }, + Request: RequestConfig{ + WebsocketHandshakeTimeout: time.Second * 30, + SocketTimeout: time.Second * 30, + 
}, + EnclaveID: "enclave", + InitialRedisPeerDBTTL: time.Minute * 120, + RecurringRedisPeerDBTTL: time.Minute * 5, + } +} diff --git a/host/config/config_test.go b/host/config/config_test.go new file mode 100644 index 0000000..1b0cf07 --- /dev/null +++ b/host/config/config_test.go @@ -0,0 +1,38 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package config + +import ( + "testing" + "time" + + "go.uber.org/zap" +) + +func TestConfig(t *testing.T) { + var yaml = ` +log: + level: info +raft: + tickDuration: 1000ms + metricPollDuration: 2h +` + conf, err := unmarshal([]byte(yaml)) + if err != nil { + t.Fatal(err) + } + if conf.Log.Level.Level() != zap.InfoLevel { + t.Errorf("conf.level=%v, want %v", conf.Log.Level.Level(), zap.InfoLevel) + } + if conf.Log.Encoding != "console" { + t.Errorf("conf.encoding=%v, want %v", conf.Log.Encoding, "console") + } + if conf.Raft.TickDuration != time.Second { + t.Errorf("conf.raft.tickDuration=%v, want %v", conf.Raft.TickDuration, time.Second) + } + if conf.Raft.MetricPollDuration != 2*time.Hour { + t.Errorf("conf.raft.metricPollDuration=%v, want %v", conf.Raft.MetricPollDuration, time.Hour*2) + } + +} diff --git a/host/config/peer.go b/host/config/peer.go new file mode 100644 index 0000000..58dc03b --- /dev/null +++ b/host/config/peer.go @@ -0,0 +1,31 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package config + +import ( + "fmt" + "time" +) + +type PeerConfig struct { + // minimum time to sleep for exponential backoff retries to connect to a peer + MinSleepDuration time.Duration `yaml:"minSleepDuration"` + // maximum time to sleep for exponential backoff retries to connect to a peer + MaxSleepDuration time.Duration `yaml:"maxSleepDuration"` + // maximum time to attempt to connect to a peer before giving up + AbandonDuration time.Duration `yaml:"abandonDuration"` + // maximum number of messages to buffer for sending to a peer + BufferSize int 
`yaml:"bufferSize"` +} + +func (p *PeerConfig) validate() []string { + var errs []string + if p.BufferSize < 1 { + errs = append(errs, fmt.Sprintf("invalid BufferSize: %v", p.BufferSize)) + } + if p.MinSleepDuration > p.MaxSleepDuration { + errs = append(errs, fmt.Sprintf("MinSleep (%v) must be less than MaxSleep (%v)", p.MinSleepDuration, p.MaxSleepDuration)) + } + return errs +} diff --git a/host/config/raft.go b/host/config/raft.go new file mode 100644 index 0000000..c2a0bb2 --- /dev/null +++ b/host/config/raft.go @@ -0,0 +1,36 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package config + +import ( + "fmt" + "time" +) + +type RaftHostConfig struct { + // how often to update the peerdb to let other peers know we're joinable + RefreshStatusDuration time.Duration `yaml:"refreshStatusDuration"` + // how often to fetch a fresh attestation in the enclave + RefreshAttestationDuration time.Duration `yaml:"refreshAttestationDuration"` + // how often to send a raft tick down to the enclave + TickDuration time.Duration `yaml:"tickDuration"` + // how often to poll metrics from the enclave + MetricPollDuration time.Duration `yaml:"metricPollDuration"` + // max number of in-flight enclave calls + EnclaveConcurrency int `yaml:"enclaveConcurrency"` +} + +func (r *RaftHostConfig) validate() []string { + var errs []string + if r.EnclaveConcurrency <= 1 { + errs = append(errs, fmt.Sprintf("invalid EnclaveConcurrency: %v", r.EnclaveConcurrency)) + } + if r.TickDuration <= 0 { + errs = append(errs, fmt.Sprintf("invalid TickDuration: %v", r.TickDuration)) + } + if r.RefreshAttestationDuration <= 0 { + errs = append(errs, fmt.Sprintf("invalid RefreshAttestationDuration: %v", r.RefreshAttestationDuration)) + } + return errs +} diff --git a/host/config/rate.go b/host/config/rate.go new file mode 100644 index 0000000..89e1046 --- /dev/null +++ b/host/config/rate.go @@ -0,0 +1,29 @@ +// Copyright 2023 Signal Messenger, LLC +// 
SPDX-License-Identifier: AGPL-3.0-only + +package config + +import ( + "fmt" + "time" +) + +type RateLimitConfig struct { + // The maximum size of the leaky bucket. This is the maximum "burst" of requests that will be allowed + BucketSize int `yaml:"bucketSize"` + // The amount of requests that will be added (up to BucketSize) per LeakRateDuration + LeakRateScalar int `yaml:"leakRateScalar"` + // The period at which LeakRateScalar additional requests will be allowed + LeakRateDuration time.Duration `yaml:"leakRateDuration"` +} + +func (r *RateLimitConfig) validate() []string { + var errs []string + if r.BucketSize < 0 { + errs = append(errs, fmt.Sprintf("invalid BucketSize: %v", r.BucketSize)) + } + if r.LeakRateScalar < 0 { + errs = append(errs, fmt.Sprintf("invalid LeakRateDuration: %v", r.LeakRateScalar)) + } + return errs +} diff --git a/host/config/redis.go b/host/config/redis.go new file mode 100644 index 0000000..c31eadc --- /dev/null +++ b/host/config/redis.go @@ -0,0 +1,37 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package config + +import ( + "fmt" + "strings" + "time" +) + +type RedisConfig struct { + // A seed list of host:port addresses of cluster nodes. 
+ Addrs []string `yaml:"addrs"` + // password for instance (may be blank if protected mode is disabled) + Password string `yaml:"password"` + // a unique name for the deployment + Name string `yaml:"name"` + // minimum time to sleep for exponential backoff retries to redis + MinSleepDuration time.Duration `yaml:"minSleepDuration"` + // maximum time to sleep for exponential backoff retries to redis + MaxSleepDuration time.Duration `yaml:"maxSleepDuration"` +} + +func (r *RedisConfig) validate() []string { + var errs []string + if len(r.Addrs) == 0 { + errs = append(errs, fmt.Sprintf("must provide redis Addrs")) + } + for _, addr := range r.Addrs { + spl := strings.Split(addr, ":") + if len(spl) != 2 { + errs = append(errs, fmt.Sprintf("invalid redis Addr %v", addr)) + } + } + return errs +} diff --git a/host/config/request.go b/host/config/request.go new file mode 100644 index 0000000..dff74d1 --- /dev/null +++ b/host/config/request.go @@ -0,0 +1,28 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package config + +import ( + "fmt" + "time" +) + +type RequestConfig struct { + // Timeout to perform websocket handshake over http connection + WebsocketHandshakeTimeout time.Duration `yaml:"socketTimeout"` + + // Timeout for websocket read/write operations + SocketTimeout time.Duration `yaml:"socketTimeout"` +} + +func (r *RequestConfig) validate() []string { + var errs []string + if r.WebsocketHandshakeTimeout <= 0 { + errs = append(errs, fmt.Sprintf("Handshake timeout %v must be >0", r.WebsocketHandshakeTimeout)) + } + if r.SocketTimeout <= 0 { + errs = append(errs, fmt.Sprintf("Socket timeout %v must be >0", r.SocketTimeout)) + } + return errs +} diff --git a/host/dispatch/dispatcher.go b/host/dispatch/dispatcher.go new file mode 100644 index 0000000..c175ca2 --- /dev/null +++ b/host/dispatch/dispatcher.go @@ -0,0 +1,350 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package dispatch + 
// Dispatcher routes messages between host and
// enclave, associating requests and replies.
//
// Host code hands messages to Send/SendTransaction; worker goroutines started
// by Run forward them to the enclave, and messages arriving on rx are routed
// either back to the waiting caller (matched by request id) or on to a peer.
type Dispatcher struct {
	// destination for every message forwarded to the enclave
	enclave EnclaveSender
	// generate unique request ids
	txGen *util.TxGenerator
	// messages to send to enclave; written by Send, drained by the
	// forwardToEnclaveLoop workers
	tx chan *toEnclave
	// messages received from enclave; a nil message is treated as the
	// enclave having closed
	rx <-chan *pb.EnclaveMessage
	// updated with enclave metrics
	metricsUpdater *metricsUpdater
	// receivers represents requests from the host waiting for a reply,
	// keyed by request id; guarded by receiversMu
	receiversMu sync.Mutex
	receivers   map[txid]chan *pb.EnclaveMessage
	// configuration (tick/poll/refresh intervals and enclave concurrency)
	cfg config.RaftHostConfig
}
txGen, + receivers: make(map[txid]chan *pb.EnclaveMessage), + tx: make(chan *toEnclave), + rx: rx, + metricsUpdater: newMetricsUpdater(), + cfg: cfg, + } +} + +// Send sends a message to the enclave and potentially wait for a reply. If p is a message +// that requires no reply, Send will still block until the enclave has processed the message +// and this method will return nil. +func (d *Dispatcher) Send(p *pb.UntrustedMessage) (*pb.EnclaveMessage, error) { + recv := make(chan *pb.EnclaveMessage, 1) + d.tx <- &toEnclave{p, recv} + response := <-recv + if _, expectReply := p.Inner.(*pb.UntrustedMessage_H2ERequest); response == nil && expectReply { + // we expected a response but got a closed channel + return nil, errors.New("failed to get enclave response") + } + return response, nil +} + +// SendTransaction is like [Send] but for requests to the enclave that expect a response +// The provided request should not be tagged with a RequestID, this will be handled internally +func (d *Dispatcher) SendTransaction(req *pb.HostToEnclaveRequest) (*pb.HostToEnclaveResponse, error) { + if req.RequestId != 0 { + return nil, errors.New("illegal SendTransaction : should not provide a request id") + } + req.RequestId = d.txGen.NextID() + wrappedResp, err := d.Send(&pb.UntrustedMessage{Inner: &pb.UntrustedMessage_H2ERequest{ + H2ERequest: req, + }}) + if err != nil { + return nil, err + } + reply, ok := wrappedResp.Inner.(*pb.EnclaveMessage_H2EResponse) + if !ok { + return nil, errors.New("unexpected response type from enclave") + } + return reply.H2EResponse, nil +} + +// Run runs the dispatcher process until cancelled or encountering a fatal error +func (d *Dispatcher) Run(ctx context.Context, peerSender PeerSender) error { + grp, ctx := errgroup.WithContext(ctx) + grp.Go(func() error { return d.forwardToEnclaveLoop(ctx) }) + grp.Go(func() error { return d.forwardToHostLoop(ctx, peerSender) }) + grp.Go(func() error { return d.tickLoop(ctx) }) + grp.Go(func() error { return 
d.metricLoop(ctx) }) + grp.Go(func() error { return d.refreshAttestationLoop(ctx) }) + err := grp.Wait() + + return err +} + +// forwardToEnclaveLoop takes messages receieved via Send and forwards them to the enclave +func (d *Dispatcher) forwardToEnclaveLoop(ctx context.Context) error { + eg, ctx := errgroup.WithContext(ctx) + + // Rather than let it compete with host originated message sends, one of our concurrency + // permits is reserved for our tick thread. This ensures ticks run in a timely fashion. + enclaveConcurrency := d.cfg.EnclaveConcurrency - 1 + + for i := 0; i < enclaveConcurrency; i++ { + eg.Go(func() error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case toEnclave := <-d.tx: + if err := d.forwardToEnclave(toEnclave); err != nil { + logger.Errorw("failed to send message to enclave", "err", err) + } + } + } + }) + } + return eg.Wait() +} + +// forwardToEnclaveLoop takes messages receieved from the enclave and forwards them to the host +func (d *Dispatcher) forwardToHostLoop(ctx context.Context, peerSender PeerSender) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case fromEnclave := <-d.rx: + if fromEnclave == nil { + return ErrEnclaveClosed + } + if err := d.forwardToHost(peerSender, fromEnclave); err != nil { + logger.Errorw("dropping enclave message", "err", err) + } + } + } +} + +// tickLoop sends tick messages to the enclave on a fixed interval +func (d *Dispatcher) tickLoop(ctx context.Context) error { + ticker := time.NewTicker(d.cfg.TickDuration) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + tick := pb.UntrustedMessage{Inner: &pb.UntrustedMessage_TimerTick{TimerTick: &pb.TimerTick{NewTimestampUnixSecs: uint64(time.Now().Unix())}}} + if err := d.enclave.SendMessage(&tick); err != nil { + return err + } + metrics.IncrCounter(tickCounterName, 1) + } + } +} + +// refreshAttestationLoop sends RefreshAttestation messages to the enclave on a fixed 
interval +func (d *Dispatcher) refreshAttestationLoop(ctx context.Context) error { + ticker := time.NewTicker(d.cfg.RefreshAttestationDuration) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if err := d.refreshAttestation(); err != nil { + return err + } + } + } +} + +func (d *Dispatcher) refreshAttestation() error { + resp, err := d.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_RefreshAttestation{ + RefreshAttestation: &pb.RefreshAttestation{RotateClientKey: true}, + }, + }) + if err != nil { + return err + } + v, ok := resp.Inner.(*pb.HostToEnclaveResponse_Status) + if !ok { + return fmt.Errorf("unexpected response from enclave %v", resp) + } + if v.Status != pb.Error_OK { + logger.Warnw("failed to refresh attestation", "err", v.Status) + } + metrics.IncrCounterWithLabels(refreshAttestationCounterName, 1, []metrics.Label{ + {Name: "success", Value: strconv.FormatBool(v.Status == pb.Error_OK)}, + }) + return nil +} + +// metricLoop sends requests for metrics to the enclave on a fixed interval +func (d *Dispatcher) metricLoop(ctx context.Context) error { + poller := time.NewTicker(d.cfg.MetricPollDuration) + defer poller.Stop() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-poller.C: + // First get status, since it might reset base labels. + resp, err := d.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_GetEnclaveStatus{GetEnclaveStatus: true}, + }) + if err != nil { + return err + } + switch v := resp.Inner.(type) { + case *pb.HostToEnclaveResponse_GetEnclaveStatusReply: + d.metricsUpdater.updateStatus(v.GetEnclaveStatusReply) + case *pb.HostToEnclaveResponse_Status: + logger.Warnf("failed to poll status from enclave", "err", v.Status.String()) + default: + return errors.New("unexpected HostToEnclaveResponse from enclave") + } + // Then, get actual metrics. 
+ resp, err = d.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_RequestMetrics{RequestMetrics: true}, + }) + if err != nil { + return err + } + switch v := resp.Inner.(type) { + case *pb.HostToEnclaveResponse_MetricsReply: + d.metricsUpdater.updateMetrics(v.MetricsReply) + case *pb.HostToEnclaveResponse_Status: + logger.Warnf("failed to poll metrics from enclave", "err", v.Status.String()) + default: + return errors.New("unexpected HostToEnclaveResponse from enclave") + } + } + } +} + +func (d *Dispatcher) forwardToEnclave(toEnclave *toEnclave) error { + metrics.IncrCounterWithLabels(processMessageCounterName, 1, []metrics.Label{{Name: "destination", Value: "enclave"}}) + p := toEnclave.message + request, expectReply := p.Inner.(*pb.UntrustedMessage_H2ERequest) + + if !expectReply { + // host does not expect a reply, fire and forget + err := d.enclave.SendMessage(toEnclave.message) + close(toEnclave.recv) + return err + } + + d.setReceiver(txid(request.H2ERequest.RequestId), toEnclave.recv) + if err := d.enclave.SendMessage(toEnclave.message); err != nil { + d.deleteReceiver(txid(request.H2ERequest.RequestId)) + close(toEnclave.recv) + return err + } + return nil +} + +func (d *Dispatcher) forwardToHost(peerSender PeerSender, message *pb.EnclaveMessage) error { + switch v := message.Inner.(type) { + case *pb.EnclaveMessage_H2EResponse: + metrics.IncrCounterWithLabels(processMessageCounterName, 1, []metrics.Label{{Name: "destination", Value: "host"}}) + id := txid(v.H2EResponse.RequestId) + recv, ok := d.getReceiver(id) + if !ok { + return fmt.Errorf("response %v has no associated request", message) + } + recv <- message + d.deleteReceiver(id) + close(recv) + case *pb.EnclaveMessage_PeerMessage: + metrics.IncrCounterWithLabels(processMessageCounterName, 1, []metrics.Label{{Name: "destination", Value: "peer"}}) + if err := peerSender.Send(v.PeerMessage); errors.Is(err, peer.ErrResetConnection) { + // The peerSender is full, we should 
attempt to reconnect so we can drop some messages + err := d.resetPeerConnection(v.PeerMessage.PeerId) + if err != nil { + logger.Errorw("failed to reset peer connection", "peerID", v.PeerMessage.PeerId, "err", err) + } + metrics.IncrCounterWithLabels(peerReconnectCounterName, 1, []metrics.Label{{Name: "success", Value: strconv.FormatBool(err != nil)}}) + } else if err != nil { + return err + } + } + return nil +} + +func (d *Dispatcher) resetPeerConnection(peerID []byte) error { + _, err := d.Send(&pb.UntrustedMessage{ + Inner: &pb.UntrustedMessage_ResetPeer{ + ResetPeer: &pb.EnclavePeer{ + PeerId: peerID[:], + }, + }, + }) + return err +} + +func (d *Dispatcher) getReceiver(id txid) (recv chan *pb.EnclaveMessage, exists bool) { + d.receiversMu.Lock() + defer d.receiversMu.Unlock() + recv, exists = d.receivers[id] + return +} + +func (d *Dispatcher) setReceiver(id txid, recv chan *pb.EnclaveMessage) { + d.receiversMu.Lock() + defer d.receiversMu.Unlock() + d.receivers[id] = recv +} + +func (d *Dispatcher) deleteReceiver(id txid) { + d.receiversMu.Lock() + defer d.receiversMu.Unlock() + delete(d.receivers, id) +} diff --git a/host/dispatch/dispatcher_test.go b/host/dispatch/dispatcher_test.go new file mode 100644 index 0000000..20bb0d4 --- /dev/null +++ b/host/dispatch/dispatcher_test.go @@ -0,0 +1,268 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package dispatch + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/peer" + pb "github.com/signalapp/svr2/proto" + "github.com/signalapp/svr2/util" + "google.golang.org/protobuf/proto" +) + +type mockEnclave struct { + ech chan *pb.EnclaveMessage // messages from the 'enclave' to the dispatcher + uch chan *pb.UntrustedMessage // messages from the host to the 'enclave' + + // can be used to delay requests arbitrarily + requestIdx *atomic.Int32 + blocks map[int32]chan struct{} +} + +func (m *mockEnclave) 
SendMessage(p *pb.UntrustedMessage) error { + index := m.requestIdx.Add(1) - 1 + if c, exists := m.blocks[index]; exists { + <-c + } + m.uch <- p + return nil +} + +func (m *mockEnclave) close() { + close(m.ech) + close(m.uch) +} + +type fixture struct { + m *mockEnclave + d *Dispatcher +} + +type mockPeerSender struct{} + +func (*mockPeerSender) Send(*pb.PeerMessage) error { return nil } + +func makeFixture() fixture { + return makeFixtureWithSender(&mockPeerSender{}) +} + +func makeFixtureWithSender(peerSender PeerSender) fixture { + m := mockEnclave{ + make(chan *pb.EnclaveMessage), + make(chan *pb.UntrustedMessage), + &atomic.Int32{}, + make(map[int32]chan struct{}), + } + dispatcher := New(config.RaftHostConfig{ + TickDuration: time.Second, + EnclaveConcurrency: 3, + }, &util.TxGenerator{}, &m, m.ech) + go dispatcher.forwardToEnclaveLoop(context.Background()) + go dispatcher.forwardToHostLoop(context.Background(), peerSender) + return fixture{&m, dispatcher} +} + +func (f *fixture) Close() { f.m.close() } +func (f *fixture) hostSend(p *pb.UntrustedMessage) chan *pb.EnclaveMessage { + recv := make(chan *pb.EnclaveMessage) + go func() { + defer close(recv) + resp, _ := f.d.Send(p) + recv <- resp + }() + // wait for the "enclave" to receive the request + <-f.m.uch + return recv +} +func (f *fixture) enclaveSend(p *pb.EnclaveMessage) { + f.m.ech <- p +} + +func TestRequestResponse(t *testing.T) { + f := makeFixture() + defer f.Close() + + // send a request and return a reply through the enclave + recv := f.hostSend(untrustedReq(7)) + f.enclaveSend(enclaveReply(7)) + v, ok := (<-recv).Inner.(*pb.EnclaveMessage_H2EResponse) + + if !ok { + t.Errorf("Send() received no response") + } + if v.H2EResponse.RequestId != 7 { + t.Errorf("Send() response = %v, want %v", v.H2EResponse.RequestId, 7) + } +} + +func TestUnorderedResponses(t *testing.T) { + f := makeFixture() + defer f.Close() + + // send 2 requests + recv1 := f.hostSend(untrustedReq(7)) + recv2 := 
f.hostSend(untrustedReq(8)) + + // reply to second message + f.enclaveSend(enclaveReply(8)) + + select { + case <-recv1: + t.Errorf("Expected reply(8), got reply(7)") + case r := <-recv2: + v, ok := r.Inner.(*pb.EnclaveMessage_H2EResponse) + if !ok { + t.Errorf("Send() received no response") + } + if v.H2EResponse.RequestId != 8 { + t.Errorf("Send() response = %v, want %v", v.H2EResponse.RequestId, 8) + } + } + + // reply to first message + f.enclaveSend(enclaveReply(7)) + v, ok := (<-recv1).Inner.(*pb.EnclaveMessage_H2EResponse) + + if !ok { + t.Errorf("Send() received no response") + } + if v.H2EResponse.RequestId != 7 { + t.Errorf("Send() response = %v, want %v", v.H2EResponse.RequestId, 7) + } +} + +func TestNoReply(t *testing.T) { + f := makeFixture() + defer f.Close() + recv1 := f.hostSend(untrustedCommand()) + resp := <-recv1 + if resp != nil { + t.Errorf("Send() = %v, want = %v, send of reply should immediately finish", resp, nil) + } +} + +type slowPeerSender struct { + block bool + out chan *pb.PeerMessage +} + +func (s *slowPeerSender) Send(m *pb.PeerMessage) error { + if s.block { + return peer.ErrResetConnection + } + s.out <- m + return nil +} + +func TestReconnectPeer(t *testing.T) { + sender := &slowPeerSender{out: make(chan *pb.PeerMessage, 1)} + f := makeFixtureWithSender(sender) + defer f.Close() + + peerMessage := &pb.EnclaveMessage{Inner: &pb.EnclaveMessage_PeerMessage{ + PeerMessage: &pb.PeerMessage{ + PeerId: []byte{}, + Inner: &pb.PeerMessage_Data{Data: []byte{}}, + }, + }} + + // return an error + sender.block = true + f.enclaveSend(peerMessage) + + // the dispatcher should send a ResetPeer to the enclave + msg := <-f.m.uch + if msg.GetResetPeer() == nil { + t.Errorf("expected ResetPeer, got %v", msg) + } + + // restore sender + sender.block = false + f.enclaveSend(peerMessage) + if got := <-sender.out; !proto.Equal(got, peerMessage.GetPeerMessage()) { + t.Errorf("dispatcher forwarded %v, want %v", got, peerMessage) + } +} + +func 
TestConcurrentRequests(t *testing.T) { + f := makeFixture() + defer f.Close() + + // block the first three enclave requests + for i := int32(0); i < 3; i++ { + f.m.blocks[i] = make(chan struct{}) + } + + finished := make(chan struct{}, 6) + for i := 0; i < 6; i++ { + go func() { + f.d.Send(untrustedCommand()) + finished <- struct{}{} + }() + } + if !channelEmpty(f.m.uch) { + t.Fatal("expected no requests to make it to enclave") + } + + // unblock 0 and 1 + f.m.blocks[0] <- struct{}{} + f.m.blocks[1] <- struct{}{} + + // have 1 free permit, and all other requests are unblocked, + // should be able to process all but requestId 2 + for i := 0; i < 5; i++ { + <-f.m.uch + } + + if !channelEmpty(f.m.uch) { + t.Fatal("requestId 3 should be blocked") + } + f.m.blocks[2] <- struct{}{} + <-f.m.uch + + for i := 0; i < 6; i++ { + <-finished + } + +} + +func channelEmpty[T any](ch chan T) bool { + select { + case <-ch: + return false + default: + return true + } +} + +func untrustedReq(id uint64) *pb.UntrustedMessage { + return &pb.UntrustedMessage{ + Inner: &pb.UntrustedMessage_H2ERequest{ + H2ERequest: &pb.HostToEnclaveRequest{ + RequestId: id, + }, + }, + } +} + +// doesn't require a response +func untrustedCommand() *pb.UntrustedMessage { + return &pb.UntrustedMessage{Inner: &pb.UntrustedMessage_TimerTick{TimerTick: &pb.TimerTick{NewTimestampUnixSecs: uint64(time.Now().Unix())}}} +} + +func enclaveReply(id uint64) *pb.EnclaveMessage { + return &pb.EnclaveMessage{ + Inner: &pb.EnclaveMessage_H2EResponse{ + H2EResponse: &pb.HostToEnclaveResponse{ + RequestId: id, + Inner: &pb.HostToEnclaveResponse_Status{Status: pb.Error_OK}, + }, + }, + } +} diff --git a/host/dispatch/metrics.go b/host/dispatch/metrics.go new file mode 100644 index 0000000..b8e1827 --- /dev/null +++ b/host/dispatch/metrics.go @@ -0,0 +1,133 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package dispatch + +import ( + "github.com/google/go-cmp/cmp" + 
"github.com/signalapp/svr2/peerid" + + metrics "github.com/armon/go-metrics" + pb "github.com/signalapp/svr2/proto" +) + +type metricsWriter interface { + IncrCounterWithLabels(key []string, val float32, labels []metrics.Label) + SetGaugeWithLabels(key []string, val float32, labels []metrics.Label) +} + +// metricsUpdater converts polled enclave metric protobufs into metrics updates +type metricsUpdater struct { + writer metricsWriter + counters map[string][]*pb.U64PB + baseLabels []metrics.Label +} + +func newMetricsUpdater() *metricsUpdater { + return &metricsUpdater{ + writer: metrics.Default(), + counters: make(map[string][]*pb.U64PB), + baseLabels: make([]metrics.Label, 0, 16), + } +} + +// findPrevCounter returns the last value seen for the provided counter +func (m *metricsUpdater) findPrevCounter(counter *pb.U64PB) *pb.U64PB { + prev, exists := m.counters[counter.Name] + if !exists { + return nil + } + for _, c := range prev { + if cmp.Equal(c.Tags, counter.Tags) { + return c + } + } + return nil +} + +// updateMetrics updates the underlying metrics using the current enclave metrics snapshot. Because +// enclave counters are the counter's current state, the previous snapshot is stored and the diff +// between the current snapshot and the previous snapshot is calculated. 
+func (m *metricsUpdater) updateMetrics(metricsPB *pb.MetricsPB) { + for _, gauge := range metricsPB.Gauges { + m.writer.SetGaugeWithLabels([]string{gauge.Name}, float32(gauge.V), m.toLabels(gauge)) + } + for _, counter := range metricsPB.Counters { + prev := m.findPrevCounter(counter) + if prev == nil { + m.writer.IncrCounterWithLabels([]string{counter.Name}, float32(counter.V), m.toLabels(counter)) + } else { + diff := int64(counter.V - prev.V) + m.writer.IncrCounterWithLabels([]string{counter.Name}, float32(diff), m.toLabels(counter)) + } + } + + for name, counters := range m.counters { + m.counters[name] = counters[:0] + } + for _, counter := range metricsPB.Counters { + m.counters[counter.Name] = append(m.counters[counter.Name], counter) + } +} + +var ( + peerState = []string{"peer", "state"} + peerAttestationTs = []string{"peer", "last_attestation_unix_secs"} + peerNextIdx = []string{"peer", "next_idx"} + peerMatchIdx = []string{"peer", "match_idx"} + peerInflightIdx = []string{"peer", "inflight_idx"} +) + +func raftStatus(s *pb.EnclavePeerStatus) string { + switch { + case s.IsLeader: + return "leader" + case s.IsVoting: + return "voter" + case s.InRaft: + return "nonvoter" + default: + return "nonmember" + } +} + +func (m *metricsUpdater) updateStatus(s *pb.EnclaveReplicaStatus) { + for _, peer := range s.Peers { + if peer.Me { + m.baseLabels = m.baseLabels[:0] + m.baseLabels = append(m.baseLabels, metrics.Label{Name: "raft", Value: raftStatus(peer)}) + if id, err := peerid.Make(peer.PeerId); err == nil { + m.baseLabels = append(m.baseLabels, metrics.Label{Name: "myid", Value: id.String()}) + } + break + } + } + for _, peer := range s.Peers { + id, err := peerid.Make(peer.PeerId) + if err != nil || peer.Me { + continue + } + lbls := append(m.baseLabels, metrics.Label{Name: "peerid", Value: id.String()}) + m.writer.SetGaugeWithLabels(peerState, float32(peer.GetConnectionStatus().GetState()), lbls) + if peer.InRaft { + 
m.writer.SetGaugeWithLabels(peerAttestationTs, float32(peer.GetConnectionStatus().GetLastAttestationUnixSecs()), lbls) + if peer.ReplicationStatus != nil { + m.writer.SetGaugeWithLabels(peerNextIdx, float32(peer.ReplicationStatus.NextIndex), lbls) + m.writer.SetGaugeWithLabels(peerMatchIdx, float32(peer.ReplicationStatus.MatchIndex), lbls) + m.writer.SetGaugeWithLabels(peerInflightIdx, float32(peer.ReplicationStatus.InflightIndex), lbls) + } + } + } +} + +// toLabels extracts metrics.Labels from tags on a metrics proto +func (m *metricsUpdater) toLabels(metric *pb.U64PB) []metrics.Label { + labels := m.baseLabels + for k, v := range metric.Tags { + labels = append(labels, metrics.Label{ + Name: k, + Value: v, + }) + } + return labels +} diff --git a/host/dispatch/metrics_test.go b/host/dispatch/metrics_test.go new file mode 100644 index 0000000..211c678 --- /dev/null +++ b/host/dispatch/metrics_test.go @@ -0,0 +1,132 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package dispatch + +import ( + "fmt" + "math" + "strings" + "testing" + + pb "github.com/signalapp/svr2/proto" + + metrics "github.com/armon/go-metrics" +) + +type mockMetricsWriter struct { + data map[string]float32 +} + +func flatkey(key []string, labels []metrics.Label) string { + s := strings.Join(key, ".") + for _, label := range labels { + s += fmt.Sprintf("%v:%v", label.Name, label.Value) + } + return s +} + +func (m *mockMetricsWriter) IncrCounterWithLabels(key []string, val float32, labels []metrics.Label) { + k := flatkey(key, labels) + if v, exists := m.data[k]; !exists { + m.data[k] = val + } else { + m.data[k] = v + val + } +} + +func (m *mockMetricsWriter) SetGaugeWithLabels(key []string, val float32, labels []metrics.Label) { + k := flatkey(key, labels) + m.data[k] = val +} + +func TestFindMetrics(t *testing.T) { + m := metricsUpdater{ + writer: &mockMetricsWriter{data: make(map[string]float32)}, + counters: make(map[string][]*pb.U64PB), + } + 
initial := []*pb.U64PB{ + {Name: "m1", Tags: map[string]string{"t1": "1"}, V: 1.0}, + {Name: "m1", Tags: map[string]string{"t1": "2"}, V: 2.0}, + } + m.updateMetrics(&pb.MetricsPB{Counters: initial}) + + prev := m.findPrevCounter(&pb.U64PB{Name: "m1", Tags: map[string]string{"t1": "1"}, V: 3.0}) + if prev == nil { + t.Errorf("previous counter not found") + } + if prev != initial[0] { + t.Errorf("findPrevCounter=%v, want %v", prev, initial[0]) + } + + prev = m.findPrevCounter(&pb.U64PB{Name: "m2", Tags: map[string]string{"t1": "1"}, V: 1.0}) + if prev != nil { + t.Error("previous counter should not exist") + } +} + +func TestUpdateCounters(t *testing.T) { + measurements := []*pb.U64PB{ + {Name: "c1", Tags: map[string]string{"t1": "1"}, V: 1.0}, + {Name: "c1", Tags: map[string]string{"t1": "2"}, V: 2.0}, + {Name: "c2", Tags: map[string]string{"t1": "1"}, V: 1.0}, + } + + type update struct { + id *pb.U64PB + v uint64 + } + + tests := []struct { + name string + update update + expected float32 + }{ + {"Existing_c1_t1:1", update{measurements[0], 2}, 2.0}, + {"Existing_c1_t1:2", update{measurements[1], 3}, 3.0}, + {"Existing_c2", update{measurements[2], 2}, 2.0}, + {"ExistingDecrement", update{measurements[1], 1}, 1.0}, + {"NewName", update{&pb.U64PB{Name: "c3"}, 5}, 5.0}, + {"NewTagValue", update{&pb.U64PB{Name: "c1", Tags: map[string]string{"t2": "3"}}, 5}, 5.0}, + {"NewTag", update{&pb.U64PB{Name: "c1", Tags: map[string]string{"t3": "3"}}, 5}, 5.0}, + } + + for _, tt := range tests { + for _, typ := range []string{"counter", "gauge"} { + + name := fmt.Sprintf("%v_%v", typ, tt.name) + + t.Run(name, func(t *testing.T) { + w := &mockMetricsWriter{data: make(map[string]float32)} + m := metricsUpdater{ + writer: w, + counters: make(map[string][]*pb.U64PB), + } + + // update with initial values, then the test update + switch typ { + case "counter": + m.updateMetrics(&pb.MetricsPB{Counters: measurements}) + m.updateMetrics(&pb.MetricsPB{Counters: []*pb.U64PB{ + {Name: 
tt.update.id.Name, Tags: tt.update.id.Tags, V: tt.update.v}, + }}) + case "gauge": + m.updateMetrics(&pb.MetricsPB{Gauges: measurements}) + m.updateMetrics(&pb.MetricsPB{Gauges: []*pb.U64PB{ + {Name: tt.update.id.Name, Tags: tt.update.id.Tags, V: tt.update.v}, + }}) + default: + t.Fatal("invalid type") + } + + got, exists := w.data[flatkey([]string{tt.update.id.Name}, m.toLabels(tt.update.id))] + if !exists { + t.Error("metric does not exist after update") + } + if math.Abs(float64(tt.expected-got)) > 0.001 { + t.Errorf("update was %v, want %v", got, tt.expected) + } + }) + } + } +} diff --git a/host/enclave.config.sample b/host/enclave.config.sample new file mode 100644 index 0000000..4484bfd --- /dev/null +++ b/host/enclave.config.sample @@ -0,0 +1,22 @@ +enclave_config { + raft { + election_ticks: 30 + heartbeat_ticks: 15 + replication_chunk_bytes: 1048576 + replica_voting_timeout_ticks: 60 + replica_membership_timeout_ticks: 300 + log_max_bytes: 104857600 + replication_pipeline: 32 + } + e2e_txn_timeout_ticks: 30 + send_timestamp_ticks: 60 +} +initial_log_level: LOG_LEVEL_INFO +group_config { + min_voting_replicas: 1 + max_voting_replicas: 5 + super_majority: 0 + db_version: DATABASE_VERSION_SVR2 + attestation_timeout: 86400 + simulated: true +} diff --git a/host/enclave/callback.go b/host/enclave/callback.go new file mode 100644 index 0000000..78a08d3 --- /dev/null +++ b/host/enclave/callback.go @@ -0,0 +1,22 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package enclave + +import ( + "reflect" + "unsafe" +) + +// #include +import "C" + +//export svr2OutputMessageGoCallback +func svr2OutputMessageGoCallback(size C.size_t, msg *C.uchar) { + var msgSlice []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&msgSlice)) + hdr.Len = int(size) + hdr.Cap = int(size) + hdr.Data = uintptr(unsafe.Pointer(msg)) + receiveMessage(msgSlice) +} diff --git a/host/enclave/enclave_test.go b/host/enclave/enclave_test.go new file mode 
100644 index 0000000..0856264 --- /dev/null +++ b/host/enclave/enclave_test.go @@ -0,0 +1,98 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package enclave + +import ( + "testing" + "time" + + pb "github.com/signalapp/svr2/proto" +) + +var ( + validConfig = pb.InitConfig{ + EnclaveConfig: &pb.EnclaveConfig{ + Raft: &pb.RaftConfig{ + ElectionTicks: 30, + HeartbeatTicks: 15, + ReplicationChunkBytes: 1 << 20, + ReplicaVotingTimeoutTicks: 120, + ReplicaMembershipTimeoutTicks: 240, + LogMaxBytes: 10 << 20, + }, + E2ETxnTimeoutTicks: 30, + }, + GroupConfig: &pb.RaftGroupConfig{ + DbVersion: pb.DatabaseVersion_DATABASE_VERSION_SVR2, + MinVotingReplicas: 1, + MaxVotingReplicas: 5, + AttestationTimeout: 3600, + Simulated: true, + }, + } +) + +func TestSimulatedEnclave(t *testing.T) { + sgx := SGXEnclave() + if err := sgx.Init("../../enclave/build/enclave.test", &validConfig); err != nil { + t.Fatal(err) + } + c := sgx.OutputMessages() + // Create and close a client. 
+ if err := sgx.SendMessage(&pb.UntrustedMessage{ + Inner: &pb.UntrustedMessage_H2ERequest{ + H2ERequest: &pb.HostToEnclaveRequest{ + RequestId: 1, + Inner: &pb.HostToEnclaveRequest_NewClient{ + NewClient: &pb.NewClientRequest{}, + }, + }, + }, + }); err != nil { + t.Fatalf("sending new client request: %v", err) + } + var clientID uint64 + select { + case msg := <-c: + if m, ok := msg.Inner.(*pb.EnclaveMessage_H2EResponse); !ok { + t.Fatalf("not EnclaveMessage_H2EResponse: %v", msg) + } else if nc, ok := m.H2EResponse.Inner.(*pb.HostToEnclaveResponse_NewClientReply); !ok { + t.Fatalf("not HostToEnclaveResponse_NewClientReply: %v", msg) + } else { + clientID = nc.NewClientReply.ClientId + } + case <-time.After(time.Second * 5): + t.Fatal("took >5s to get response") + } + if err := sgx.SendMessage(&pb.UntrustedMessage{ + Inner: &pb.UntrustedMessage_H2ERequest{ + H2ERequest: &pb.HostToEnclaveRequest{ + RequestId: 2, + Inner: &pb.HostToEnclaveRequest_CloseClient{ + CloseClient: &pb.CloseClientRequest{ + ClientId: clientID, + }, + }, + }, + }, + }); err != nil { + t.Fatalf("sending client close: %v", err) + } + select { + case msg := <-c: + if m, ok := msg.Inner.(*pb.EnclaveMessage_H2EResponse); !ok { + t.Fatalf("not EnclaveMessage_H2EResponse: %v", msg) + } else if s, ok := m.H2EResponse.Inner.(*pb.HostToEnclaveResponse_Status); !ok { + t.Fatalf("not HostToEnclaveResponse_NewClientReply: %v", msg) + } else if s.Status != pb.Error_OK { + t.Fatalf("close status, want %v got %v", pb.Error_OK, s.Status) + } + case <-time.After(time.Second * 5): + t.Fatal("took >5s to get response") + } + + go sgx.Close() + // Make sure that Close() actually closes the output channel. 
+ <-c +} diff --git a/host/enclave/iface.go b/host/enclave/iface.go new file mode 100644 index 0000000..b24f325 --- /dev/null +++ b/host/enclave/iface.go @@ -0,0 +1,17 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package enclave + +import ( + "github.com/signalapp/svr2/peerid" + + pb "github.com/signalapp/svr2/proto" +) + +type Enclave interface { + PID() peerid.PeerID + OutputMessages() <-chan *pb.EnclaveMessage + SendMessage(msgPB *pb.UntrustedMessage) error + Close() +} diff --git a/host/enclave/logging.go b/host/enclave/logging.go new file mode 100644 index 0000000..131db67 --- /dev/null +++ b/host/enclave/logging.go @@ -0,0 +1,52 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package enclave + +import ( + "unsafe" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// #include +// #include +// #include +// #include "c/svr2_u.h" +import "C" + +//export svr2LogCallback +func svr2LogCallback( + context unsafe.Pointer, + is_enclave bool, + t *C.struct_tm, + usecs C.long, + level C.oe_log_level_t, + host_thread_id uint64, + msg *C.char) { + + var zapLevel zapcore.Level + switch level { + case C.OE_LOG_LEVEL_NONE: + zapLevel = zapcore.ErrorLevel + case C.OE_LOG_LEVEL_FATAL: + // use error for fatal, fatal in go means panic + zapLevel = zapcore.ErrorLevel + case C.OE_LOG_LEVEL_ERROR: + zapLevel = zapcore.ErrorLevel + case C.OE_LOG_LEVEL_WARNING: + zapLevel = zapcore.WarnLevel + case C.OE_LOG_LEVEL_INFO: + zapLevel = zapcore.InfoLevel + case C.OE_LOG_LEVEL_VERBOSE: + zapLevel = zapcore.DebugLevel + case C.OE_LOG_LEVEL_MAX: + zapLevel = zapcore.DebugLevel + default: + zapLevel = zapcore.ErrorLevel + } + + zap.L().Log(zapLevel, C.GoString(msg), + zap.Uint64("host_thread", host_thread_id)) +} diff --git a/host/enclave/nitro.go b/host/enclave/nitro.go new file mode 100644 index 0000000..f84432c --- /dev/null +++ b/host/enclave/nitro.go @@ -0,0 +1,204 @@ +// Copyright 2023 Signal 
Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package enclave + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "net" + "os" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/signalapp/svr2/peerid" + "google.golang.org/protobuf/proto" + + pb "github.com/signalapp/svr2/proto" +) + +// #include +// #include +import "C" + +func vSock() (net.Conn, error) { + fd, err := syscall.Socket(C.AF_VSOCK, syscall.SOCK_STREAM, 0) + if err != nil { + return nil, err + } + // TODO: `connect` this socket. + f := os.NewFile(uintptr(fd), "vsock") + return net.FileConn(f) +} + +type Nitro struct { + c chan *pb.EnclaveMessage + pid peerid.PeerID + + rMu, wMu sync.Mutex + sock net.Conn + + msgIDGen uint64 + callMu sync.Mutex + calls map[uint64]chan<- error +} + +var testNitroIface Enclave = (*Nitro)(nil) + +func (n *Nitro) send(pb proto.Message) error { + buf, err := proto.Marshal(pb) + if err != nil { + return fmt.Errorf("marshaling proto: %w", err) + } + var sizeBuf [4]byte + binary.BigEndian.PutUint32(sizeBuf[:], uint32(len(buf))) + + n.wMu.Lock() + defer n.wMu.Unlock() + if _, err := n.sock.Write(sizeBuf[:]); err != nil { + return fmt.Errorf("writing size: %w", err) + } else if _, err := n.sock.Write(buf); err != nil { + return fmt.Errorf("writing proto: %v", err) + } + return nil +} + +func (n *Nitro) recv(pb proto.Message) error { + n.rMu.Lock() + defer n.rMu.Unlock() + var sizeBuf [4]byte + if _, err := io.ReadFull(n.sock, sizeBuf[:]); err != nil { + return fmt.Errorf("reading size: %w", err) + } + size := binary.BigEndian.Uint32(sizeBuf[:]) + buf := make([]byte, size) + if _, err := io.ReadFull(n.sock, buf); err != nil { + return fmt.Errorf("reading proto: %w", err) + } else if err := proto.Unmarshal(buf, pb); err != nil { + return fmt.Errorf("unmarshaling proto: %w", err) + } + return nil +} + +func NewNitro(config *pb.InitConfig) (_ *Nitro, returnedErr error) { + sock, err := vSock() + if err != nil { + return nil, fmt.Errorf("creating vsock: 
%w", err) + } + n := &Nitro{ + c: make(chan *pb.EnclaveMessage, 100), + sock: sock, + } + defer func() { + if returnedErr != nil { + n.sock.Close() + } + }() + config.InitialTimestampUnixSecs = uint64(time.Now().Unix()) + initReq := &pb.InboundMessage{ + Inner: &pb.InboundMessage_Init{Init: config}, + } + if err := n.send(initReq); err != nil { + return nil, err + } + var initResp pb.OutboundMessage + if err := n.recv(&initResp); err != nil { + return nil, fmt.Errorf("init recv: %w", err) + } else if inner, ok := initResp.Inner.(*pb.OutboundMessage_Init); !ok { + return nil, fmt.Errorf("init response was not type InitCallResponse") + } else if n.pid, err = peerid.Make(inner.Init.PeerId); err != nil { + return nil, fmt.Errorf("init received peerid: %w", err) + } + go n.readOutputs() + return n, nil +} + +func (n *Nitro) readNextOutput() error { + var out pb.OutboundMessage + if err := n.recv(&out); err != nil { + return fmt.Errorf("recv error: %w", err) + } + switch v := out.Inner.(type) { + case *pb.OutboundMessage_Init: + return fmt.Errorf("received init") + case *pb.OutboundMessage_Msg: + if err := n.receivedResponse(v.Msg); err != nil { + return fmt.Errorf("received response error: %w", err) + } + case *pb.OutboundMessage_Out: + var emsg pb.EnclaveMessage + if err := proto.Unmarshal(v.Out, &emsg); err != nil { + return fmt.Errorf("unmarshal of EnclaveMessage: %w", err) + } + n.c <- &emsg + default: + return fmt.Errorf("unexpected inner type %T", out.Inner) + } + return nil +} + +func (n *Nitro) receivedResponse(m *pb.MsgCallResponse) error { + n.callMu.Lock() + defer n.callMu.Unlock() + done := n.calls[m.Id] + if done == nil { + return fmt.Errorf("received response to msg %d which isn't an active call", m.Id) + } + delete(n.calls, m.Id) + done <- m.Status // should not block, since done is a buffered channel + return nil +} + +func (n *Nitro) readOutputs() { + var err error + for err == nil { + err = n.readNextOutput() + } + log.Printf("nitro readOutputs 
failure: %v", err) + n.sock.Close() + close(n.c) +} + +func (n *Nitro) PID() peerid.PeerID { + return n.pid +} + +func (n *Nitro) OutputMessages() <-chan *pb.EnclaveMessage { + return n.c +} + +func (n *Nitro) SendMessage(msgPB *pb.UntrustedMessage) error { + buf, err := proto.Marshal(msgPB) + if err != nil { + return fmt.Errorf("marshaling: %w", err) + } + id := atomic.AddUint64(&n.msgIDGen, 1) + in := pb.InboundMessage{ + Inner: &pb.InboundMessage_Msg{ + Msg: &pb.MsgCallRequest{ + Id: id, + Data: buf, + }, + }, + } + done := make(chan error, 1) + n.callMu.Lock() + n.calls[id] = done + n.callMu.Unlock() + if err := n.send(&in); err != nil { + n.callMu.Lock() + delete(n.calls, id) + n.callMu.Unlock() + n.sock.Close() // a failure to send is a permanent failure + return fmt.Errorf("sending: %w", err) + } + return <-done +} + +func (n *Nitro) Close() { + n.sock.Close() +} diff --git a/host/enclave/sgx.go b/host/enclave/sgx.go new file mode 100644 index 0000000..5e910ec --- /dev/null +++ b/host/enclave/sgx.go @@ -0,0 +1,223 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Package enclave connects the Go binary to C code. +// +// This package connects up ocalls/ecalls and associated OpenEnclave APIs and +// exposes them in a Go-y way to the rest of the project. 
+package enclave + +/* +#cgo pkg-config: oehost-g++ +#cgo LDFLAGS: -L./c -lsvr2 +#include "c/svr2_u.h" +#include +#include + +// Defined in `callback.go` +void svr2OutputMessageGoCallback(size_t msg_size, unsigned char* msg); + +void svr2_output_message(size_t msg_size, unsigned char* msg) { + svr2OutputMessageGoCallback(msg_size, msg); +} + +void svr2LogCallback( + void* context, + bool is_enclave, + const struct tm* t, + long usecs, + oe_log_level_t level, + uint64_t host_thread_id, + const char* message); + +int setUpLoggingInC() { + return oe_log_set_callback(0, svr2LogCallback); +} + +int setSignalOnStack(int signal) { + int ret = 0; + struct sigaction action = {0}; + if (0 != (ret = sigaction(signal, 0, &action))) { return ret; } + action.sa_flags |= SA_ONSTACK; + if (0 != (ret = sigaction(signal, &action, 0))) { return ret; } + return 0; +} +*/ +import "C" + +import ( + "fmt" + "reflect" + "sync" + "time" + "unsafe" + + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/peerid" + "google.golang.org/protobuf/proto" + + pb "github.com/signalapp/svr2/proto" +) + +type EnclaveType int + +const ( + MESSAGE_BUFFER_SIZE = 100 + PRODUCTION EnclaveType = 0 + SIMULATED EnclaveType = C.OE_ENCLAVE_FLAG_SIMULATE +) + +type SGX struct { + ptr *C.struct__oe_enclave + mu sync.RWMutex // lock for enclave + msgSendOutputs chan *pb.EnclaveMessage + pid peerid.PeerID +} + +var sgxSingleton SGX + +// OpenEnclaveError is an error that wraps an oe_result_t. +type OpenEnclaveError C.oe_result_t + +// Error implements error. +func (c OpenEnclaveError) Error() string { + return fmt.Sprintf("error calling enclave: %d==0x%x %s", uint(c), uint(c), C.GoString(C.oe_result_str(C.oe_result_t(c)))) +} + +// ReturnedError is an error that wraps a SVR2 error. +type ReturnedError C.int + +// Error implements error. 
func (r ReturnedError) Error() string {
	// NOTE(review): r (a C.int) is converted to uint before formatting, so a
	// negative code would print as a large positive number — confirm that
	// enclave error codes are never negative.
	return fmt.Sprintf("error returned from enclave: %d", uint(r))
}

// SGXEnclave returns a pointer to the package-level singleton SGX value that
// wraps the actual SGX interface.
func SGXEnclave() *SGX { return &sgxSingleton }

// setLoggingOnce ensures setUpLogging runs at most once per process.
var setLoggingOnce sync.Once

// Compile-time check that *SGX implements Enclave.
var testSGXInterface Enclave = &sgxSingleton

// setUpLogging installs the OpenEnclave log callback via setUpLoggingInC and
// panics if that call does not return OE_OK.
func setUpLogging() {
	if C.OE_OK != C.setUpLoggingInC() {
		panic("setup of logging failed")
	}
}

// PID returns the peer ID stored on s.
func (s *SGX) PID() peerid.PeerID {
	return s.pid
}

// OutputMessages returns the receive side of the channel on which messages
// emitted by the enclave are delivered.
func (s *SGX) OutputMessages() <-chan *pb.EnclaveMessage {
	return s.msgSendOutputs
}

// Init initializes the enclave and returns an error if that fails.
// [path] is the path to file containing the compiled enclave object to run.
// [config] is the enclave configuration to use.
//
// Messages that are written to the host by the enclave are delivered on the
// channel returned by OutputMessages().
// It's up to the caller to empty and process these messages; failing to do so
// will eventually block the enclave, as the buffer (size MESSAGE_BUFFER_SIZE)
// fills. The channel will be closed as part of Close(), once the enclave has
// been fully shut down.
//
// The channel guarantees that messages are sent in the order that the enclave sends them.
+func (s *SGX) Init(path string, config *pb.InitConfig) (returnedError error) { + s.mu.Lock() + defer s.mu.Unlock() + var pid peerid.PeerID + if s.ptr != nil { + return fmt.Errorf("enclave already initiated") + } + setLoggingOnce.Do(setUpLogging) + + var typ EnclaveType = PRODUCTION + if config.GroupConfig.Simulated { + typ = SIMULATED + } + + config.InitialTimestampUnixSecs = uint64(time.Now().Unix()) + configBytes, err := proto.Marshal(config) + if err != nil { + return fmt.Errorf("marshaling config proto: %v", err) + } + s.msgSendOutputs = make(chan *pb.EnclaveMessage, MESSAGE_BUFFER_SIZE) + pathC := C.CString(path) + defer C.free(unsafe.Pointer(pathC)) + if err := C.oe_create_svr2_enclave(pathC, C.OE_ENCLAVE_TYPE_SGX, C.OE_ENCLAVE_FLAG_DEBUG_AUTO|C.uint(typ), nil, 0, &s.ptr); err != C.OE_OK { + return OpenEnclaveError(err) + } + if s.ptr == nil { + panic("got nil s.ptr") + } + defer func() { + if returnedError != nil { + s.Close() + } + }() + + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&configBytes)) + var ret C.int + if oeErr := C.svr2_init(s.ptr, &ret, C.ulong(hdr.Len), (*C.uchar)(unsafe.Pointer(hdr.Data)), (*C.uchar)(unsafe.Pointer(&pid[0]))); oeErr != 0 { + return OpenEnclaveError(oeErr) + } else if ret != 0 { + return ReturnedError(ret) + } + s.pid = pid + + // Go requires that signal handlers use SA_ONSTACK, but OpenEnclave sets up some + // signal handlers without this flag. Reset them. + // Signals taken from OpenEnclave's host/sgx/linux/exception.c + for _, signal := range []C.int{C.SIGBUS, C.SIGFPE, C.SIGILL, C.SIGSEGV, C.SIGTRAP, C.SIGHUP, C.SIGABRT, C.SIGALRM, C.SIGPIPE, C.SIGPOLL, C.SIGUSR1, C.SIGUSR2} { + if r := C.setSignalOnStack(signal); r != 0 { + return fmt.Errorf("setting onstack for signal %d failed: %d", signal, r) + } + } + + return nil +} + +// SendMessage sends a message to the running enclave. Messages generated by +// the enclave during the lifetime of this call will be made available on the +// channel provided by Init. 
+func (s *SGX) SendMessage(msgPB *pb.UntrustedMessage) error { + msg, err := proto.Marshal(msgPB) + if err != nil { + return err + } + s.mu.RLock() + defer s.mu.RUnlock() + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&msg)) + var ret C.int + if oeErr := C.svr2_input_message(s.ptr, &ret, C.ulong(hdr.Len), (*C.uchar)(unsafe.Pointer(hdr.Data))); oeErr != 0 { + return OpenEnclaveError(oeErr) + } else if ret != 0 { + return ReturnedError(ret) + } + return nil +} + +// Close terminates and releases resources for the enclave. +func (s *SGX) Close() { + s.mu.Lock() + defer s.mu.Unlock() + C.oe_terminate_enclave(s.ptr) + s.ptr = nil + close(s.msgSendOutputs) + s.msgSendOutputs = nil +} + +// receiveMessage is called by svr2_output_message during svr2_input_message calls. +func receiveMessage(buf []byte) { + var msg pb.EnclaveMessage + if err := proto.Unmarshal(buf, &msg); err != nil { + logger.Errorf("This is a severe bug. Could not unmarshal a message from the enclave. dropping : %v", err) + return + } + // Check the precondition that the mainThread lock should be locked already. 
+ sgxSingleton.mu.RLock() + sgxSingleton.msgSendOutputs <- &msg + sgxSingleton.mu.RUnlock() +} diff --git a/host/go.mod b/host/go.mod new file mode 100644 index 0000000..44a6110 --- /dev/null +++ b/host/go.mod @@ -0,0 +1,39 @@ +module github.com/signalapp/svr2 + +go 1.19 + +require ( + github.com/alicebob/miniredis/v2 v2.23.1 + github.com/armon/go-metrics v0.4.1 + github.com/go-redis/redis/v8 v8.11.5 + github.com/go-redis/redis_rate/v9 v9.1.2 + golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb + golang.org/x/sync v0.1.0 +) + +require ( + github.com/DataDog/datadog-go v4.8.3+incompatible // indirect + github.com/Microsoft/go-winio v0.6.0 // indirect + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/hashicorp/go-immutable-radix v1.3.1 // indirect + github.com/hashicorp/golang-lru v0.5.0 // indirect + github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 // indirect + golang.org/x/crypto v0.1.0 // indirect + golang.org/x/mod v0.8.0 // indirect + golang.org/x/net v0.7.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/tools v0.6.0 // indirect +) + +require ( + github.com/flynn/noise v1.0.0 + github.com/google/go-cmp v0.5.9 + github.com/gorilla/websocket v1.5.0 + go.uber.org/atomic v1.10.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + go.uber.org/zap v1.24.0 + google.golang.org/protobuf v1.28.1 + gopkg.in/yaml.v2 v2.4.0 +) diff --git a/host/go.sum b/host/go.sum new file mode 100644 index 0000000..0d8a508 --- /dev/null +++ b/host/go.sum @@ -0,0 +1,175 @@ +github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/DataDog/datadog-go v4.8.3+incompatible h1:fNGaYSuObuQb5nzeTQqowRAd9bpDIRRV4/gUtIBjh8Q= +github.com/DataDog/datadog-go v4.8.3+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= 
+github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg= +github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.23.1 h1:jR6wZggBxwWygeXcdNyguCOCIjPsZyNUNlAkTx2fu0U= +github.com/alicebob/miniredis/v2 v2.23.1/go.mod h1:84TWKZlxYkfgMucPBf5SOQBYJceZeQRFIaQgNMiCX6Q= +github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJA= +github.com/armon/go-metrics v0.4.1/go.mod h1:E6amYzXo6aW1tqzoZGT755KkbgrJsSdpwZ+3JqfkOG4= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= 
+github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= +github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ= +github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-redis/redis_rate/v9 v9.1.2 h1:H0l5VzoAtOE6ydd38j8MCq3ABlGLnvvbA1xDSVVCHgQ= +github.com/go-redis/redis_rate/v9 v9.1.2/go.mod h1:oam2de2apSgRG8aJzwJddXbNu91Iyz1m8IKJE2vpvlQ= +github.com/go-stack/stack v1.8.0/go.mod 
h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= +github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= +github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo= +github.com/hashicorp/golang-lru 
v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= 
+github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= +github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= +github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 h1:5mLPGnFdSsevFRFc9q3yYbBkB6tsm4aCwwQV/j1JQAQ= +github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= +golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/exp 
v0.0.0-20230213192124-5e25df0256eb h1:PaBZQdo+iSDyHT053FjUCgZQ/9uqVwPOcl7KSWhKn6w= +golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys 
v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod 
h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/host/health/health.go b/host/health/health.go new file mode 100644 index 0000000..00c9625 --- /dev/null +++ b/host/health/health.go @@ -0,0 +1,43 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package health + +import ( + "fmt" + "net/http" + "sync" +) + +// Health wraps an error (nil means "healthy"), and provides HTTP handling +// logic to serve that error. +type Health struct { + mu sync.Mutex + err error +} + +// New creates a new health object, with initial health set based on the +// 'initial' error (nil==healthy). +func New(initial error) *Health { + return &Health{err: initial} +} + +// Set sets the underlying error for this Health object; err=nil means "OK" +func (h *Health) Set(err error) { + h.mu.Lock() + h.err = err + h.mu.Unlock() +} + +// ServeHTTP implements http.Handler. 
+func (h *Health) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.mu.Lock() + err := h.err + h.mu.Unlock() + if err == nil { + fmt.Fprintf(w, "ok") + return + } + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "error: %v", err) +} diff --git a/host/health/health_test.go b/host/health/health_test.go new file mode 100644 index 0000000..9445dbd --- /dev/null +++ b/host/health/health_test.go @@ -0,0 +1,53 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package health + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" +) + +func TestServingFromHealthy(t *testing.T) { + h := New(nil) + ts := httptest.NewServer(h) + defer ts.Close() + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusOK { + t.Errorf("nil error returned status: %v", res.Status) + } + h.Set(errors.New("FUBAR")) + res, err = http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusInternalServerError { + t.Errorf("non-nil error returned status: %v", res.Status) + } +} + +func TestServingFromUnhealthy(t *testing.T) { + h := New(errors.New("FUBAR")) + ts := httptest.NewServer(h) + defer ts.Close() + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusInternalServerError { + t.Errorf("non-nil error returned status: %v", res.Status) + } + h.Set(nil) + res, err = http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusOK { + t.Errorf("nil error returned status: %v", res.Status) + } +} diff --git a/host/host.config.sample b/host/host.config.sample new file mode 100644 index 0000000..197d3e6 --- /dev/null +++ b/host/host.config.sample @@ -0,0 +1,13 @@ +clientListenAddr: localhost:8080 +controlListenAddr: localhost:8081 +peerAddr: localhost:8082 +enclaveId: 17e1cb662572d28e0eb5a492ed8df949bc2cfcf3f2098b710e7b637759d6dcb3 + +raft: + tickDuration: 1s + metricPollDuration: 10s 
+ refreshStatusDuration: 15s + enclaveConcurrency: 8 + +redis: + addrs: [localhost:9999] diff --git a/host/integration/integration_test.go b/host/integration/integration_test.go new file mode 100644 index 0000000..487588b --- /dev/null +++ b/host/integration/integration_test.go @@ -0,0 +1,325 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package service + +import ( + "bytes" + "context" + "encoding/base64" + "flag" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/gorilla/websocket" + "github.com/signalapp/svr2/auth" + "github.com/signalapp/svr2/servicetest" + "golang.org/x/sync/errgroup" + + pb "github.com/signalapp/svr2/proto" +) + +// These tests spin up a full SVR cluster across multiple processes. They require +// a compiled host binary to function, which can be explicitly provided via CLI flag + +var ( + enclavePath = flag.String("enclave_path", "../../enclave/build/enclave.test", "Path to binary holding the enclave") + hostPath = flag.String("host_path", "../main", "Path to go host binary") + econfigPath = flag.String("econfig_path", "testdata/enclave.config", "Path to enclave configuration") + numNodes = flag.Int("num_nodes", 3, "Number of nodes in the raft group") + + svrGroup group + authSecret = "123456" + data = []byte("some test data. 
must be at least 16 bytes") +) + +func userName(i int) string { + return fmt.Sprintf("%032x", i) +} + +func TestIntegration(t *testing.T) { + host := fmt.Sprintf("localhost:%v", port(clientType, 1)) + u := url.URL{Scheme: "ws", Host: host, Path: "v1/enclave"} + pin := backup(t, testClient(t, u, userName(9999))) + expose(t, testClient(t, u, userName(9999))) + restore(t, testClient(t, u, userName(9999)), pin) +} + +func TestConcurrentClients(t *testing.T) { + host := fmt.Sprintf("localhost:%v", port(clientType, 1)) + u := url.URL{Scheme: "ws", Host: host, Path: "v1/enclave"} + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + user := userName(i) + wg.Add(1) + go func() { + defer wg.Done() + pin := backup(t, testClient(t, u, user)) + expose(t, testClient(t, u, user)) + restore(t, testClient(t, u, user), pin) + }() + } + wg.Wait() +} + +func TestServerDelete(t *testing.T) { + user := userName(123) + host := fmt.Sprintf("localhost:%v", port(clientType, 1)) + u := url.URL{Scheme: "ws", Host: host, Path: "v1/enclave"} + pin := backup(t, testClient(t, u, user)) + expose(t, testClient(t, u, user)) + restore(t, testClient(t, u, user), pin) + + // use the server http endpoint to delete the backup for user + deleteURL := fmt.Sprintf("http://localhost:%v/v1/delete", port(clientType, 2)) + req, err := http.NewRequest(http.MethodDelete, deleteURL, nil) + if err != nil { + t.Fatal(err) + } + req.Header = authHeaders(user) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("delete failed: %v", string(body)) + } + + // use the websocket to check that it's really gone + tc := testClient(t, u, user) + r := tc.Send(&pb.Request{Inner: &pb.Request_Restore{ + Restore: &pb.RestoreRequest{ + Pin: pin, + }, + }}) + if rr, ok := r.Inner.(*pb.Response_Restore); !ok { + t.Fatalf("Unexpected response to restore: %v", r) + } else if rr.Restore.Status != 
pb.RestoreResponse_MISSING { + t.Fatalf("Incorrect response: %v, backup should be missing", rr) + } +} + +func authHeaders(user string) http.Header { + authenticator := auth.New([]byte(authSecret)) + headers := http.Header{} + headers.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(user+":"+authenticator.PassFor(user)))) + return headers +} + +func testClient(t *testing.T, u url.URL, user string) *servicetest.TestClient { + headers := authHeaders(user) + log.Printf("using headers: %+v", headers) + c, _, err := websocket.DefaultDialer.Dial(u.String(), headers) + if err != nil { + t.Fatalf("dial: %v", err) + } + + return servicetest.NewTestClient(t, c) +} + +func backup(t *testing.T, tc *servicetest.TestClient) (pin []byte) { + // send a backup request + defer tc.Sc.Close() + pin = servicetest.RandBytes(t, 32) + + r := tc.Send(&pb.Request{Inner: &pb.Request_Backup{ + Backup: &pb.BackupRequest{ + Data: data, + Pin: pin, + MaxTries: 5, + }, + }}) + if br, ok := r.Inner.(*pb.Response_Backup); !ok { + t.Fatalf("Unexpected response to backup: %v", r) + } else if br.Backup.Status != pb.BackupResponse_OK { + t.Fatalf("Incorrect response: %v", br) + } + return +} + +func expose(t *testing.T, tc *servicetest.TestClient) { + defer tc.Sc.Close() + // send a expose request + r := tc.Send(&pb.Request{Inner: &pb.Request_Expose{ + Expose: &pb.ExposeRequest{ + Data: data, + }, + }}) + if br, ok := r.Inner.(*pb.Response_Expose); !ok { + t.Fatalf("Unexpected response to expose: %v", r) + } else if br.Expose.Status != pb.ExposeResponse_OK { + t.Fatalf("Incorrect response: %v", br) + } +} + +func restore(t *testing.T, tc *servicetest.TestClient, pin []byte) { + defer tc.Sc.Close() + // send a restore request + r := tc.Send(&pb.Request{Inner: &pb.Request_Restore{ + Restore: &pb.RestoreRequest{ + Pin: pin, + }, + }}) + if rr, ok := r.Inner.(*pb.Response_Restore); !ok { + t.Fatalf("Unexpected response to restore: %v", r) + } else if rr.Restore.Status != 
pb.RestoreResponse_OK { + t.Fatalf("Incorrect response: %v", rr) + } else if !bytes.Equal(rr.Restore.Data, data) { + t.Fatalf("Restored bytes %v, want %v", rr.Restore.Data, data) + } +} + +func initializeAndRun(m *testing.M) int { + svrGroup = start() + defer svrGroup.stop() + return m.Run() + +} + +func TestMain(m *testing.M) { + flag.Parse() + if testing.Short() { + return + } + os.Exit(initializeAndRun(m)) +} + +type group struct { + cancel context.CancelFunc + ctx context.Context + dir string + eg *errgroup.Group +} + +type addrType int + +const ( + controlType addrType = iota + clientType + peerType +) + +func port(typ addrType, portOffset int) int { + switch typ { + case controlType: + return 8090 + portOffset + case clientType: + return 8080 + portOffset + case peerType: + return 9000 + portOffset + } + return 0 +} + +func hconfig(w io.Writer, dir string, portOffset int, redisAddr string) error { + _, err := fmt.Fprintf(w, ` +peerAddr: localhost:%v +clientListenAddr: localhost:%v +controlListenAddr: localhost:%v +raft: + tickDuration: 100ms +redis: + addrs: [%v]`, + port(peerType, portOffset), + port(clientType, portOffset), + port(controlType, portOffset), + redisAddr) + return err +} + +func (g *group) stop() { + g.cancel() + g.eg.Wait() + os.RemoveAll(g.dir) +} + +// keep trying ready endpoint on each node until 200 or timeout +func waitForReady() error { + for portOffset := 1; portOffset <= *numNodes; portOffset++ { + url := fmt.Sprintf("http://localhost:%v/health/ready", port(controlType, portOffset)) + if err := servicetest.WaitFor200(time.Minute, url); err != nil { + return err + } + } + return nil +} + +func start() group { + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.WithContext(ctx) + + dir, err := os.MkdirTemp("", "host") + if err != nil { + log.Fatal(err) + } + + redis, err := miniredis.Run() + if err != nil { + log.Fatal(err) + } + defer redis.Close() + + // start numNodes SVR processes + for i := 1; i <= 
*numNodes; i++ { + f, err := os.Create(filepath.Join(dir, fmt.Sprintf("hostconfig-%d", i))) + if err != nil { + log.Fatalf("create tempfile : %v", err) + } + if err := hconfig(f, dir, i, redis.Addr()); err != nil { + log.Fatalf("write config : %v", err) + } + f.Close() + + cmd := exec.CommandContext( + ctx, + *hostPath, + "-hconfig_path", f.Name(), + "-econfig_path", *econfigPath, + "-enclave_path", *enclavePath) + cmd.Env = append(cmd.Env, "AUTH_SECRET="+base64.StdEncoding.EncodeToString([]byte(authSecret))) + + stderr, err := cmd.StderrPipe() + if err != nil { + log.Fatal(err) + } + if err = cmd.Start(); err != nil { + log.Fatalf("cmd start: %v", err) + } + + peerIndex := i + eg.Go(func() error { + slurp, _ := io.ReadAll(stderr) + log.Printf("peer(%v): %s\n", peerIndex, slurp) + return cmd.Wait() + }) + } + + // wait for all nodes to join raft + err = waitForReady() + if err != nil { + cancel() + eg.Wait() + log.Fatal(err) + } + + return group{ + cancel: cancel, + ctx: ctx, + dir: dir, + eg: eg, + } + +} diff --git a/host/integration/testdata/enclave.config b/host/integration/testdata/enclave.config new file mode 100644 index 0000000..5bf272f --- /dev/null +++ b/host/integration/testdata/enclave.config @@ -0,0 +1,20 @@ +enclave_config { + raft { + election_ticks: 30 + heartbeat_ticks: 15 + replication_chunk_bytes: 1048576 + replica_voting_timeout_ticks: 60 + replica_membership_timeout_ticks: 300 + log_max_bytes: 10000000 + } + e2e_txn_timeout_ticks: 30 +} +initial_log_level: LOG_LEVEL_DEBUG +group_config { + min_voting_replicas: 1 + max_voting_replicas: 5 + super_majority: 0 + db_version: DATABASE_VERSION_SVR2 + attestation_timeout: 86400 + simulated: true +} diff --git a/host/logger/logger.go b/host/logger/logger.go new file mode 100644 index 0000000..b2e29ad --- /dev/null +++ b/host/logger/logger.go @@ -0,0 +1,60 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Package logger provides helper functions to configure and 
use a global logger +package logger + +import ( + "log" + + "github.com/signalapp/svr2/config" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// init sets up reasonable logging defaults for tests / non-main +func init() { + Init(config.Default()) +} + +// Init configures global loggers accessed with logging functions in this module. +// Optionally provided fields will bind some key/value pairs to the global logger +func Init(cfg *config.Config) { + z, err := cfg.Log.Build(zap.AddCallerSkip(1)) + if err != nil { + log.Fatalf("zap init: %v", err) + } + zap.ReplaceGlobals(z) +} + +// WithGlobal binds some key/value pairs to the global logger +func WithGlobal(fields ...zapcore.Field) { + zap.ReplaceGlobals(zap.L().With(fields...)) +} + +// With returns a logger with some bound key/value pairs +func With(keysAndValues ...interface{}) *Logger { + return &Logger{zap.S().With(keysAndValues...)} +} + +// Sync flushes any buffered logs. Applications should call Sync before program exit. +func Sync() { + zap.L().Sync() +} + +type Logger struct { + *zap.SugaredLogger +} + +// wrappers around sugared zap logging methods that use the zap global logger + +func Infow(msg string, keysAndValues ...interface{}) { zap.S().Infow(msg, keysAndValues...) } +func Infof(template string, args ...interface{}) { zap.S().Infof(template, args...) } +func Debugw(msg string, keysAndValues ...interface{}) { zap.S().Debugw(msg, keysAndValues...) } +func Debugf(template string, args ...interface{}) { zap.S().Debugf(template, args...) } +func Warnw(msg string, keysAndValues ...interface{}) { zap.S().Warnw(msg, keysAndValues...) } +func Warnf(template string, args ...interface{}) { zap.S().Warnf(template, args...) } +func Errorw(msg string, keysAndValues ...interface{}) { zap.S().Errorw(msg, keysAndValues...) } +func Errorf(template string, args ...interface{}) { zap.S().Errorf(template, args...) } +func Fatalw(msg string, keysAndValues ...interface{}) { zap.S().Fatalw(msg, keysAndValues...)
} +func Fatalf(template string, args ...interface{}) { zap.S().Fatalf(template, args...) } diff --git a/host/main.go b/host/main.go new file mode 100644 index 0000000..c9153af --- /dev/null +++ b/host/main.go @@ -0,0 +1,97 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package main + +import ( + "context" + "encoding/base64" + "flag" + "os" + "os/signal" + + "github.com/armon/go-metrics/datadog" + "github.com/signalapp/svr2/auth" + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/service" + "google.golang.org/protobuf/encoding/prototext" + + stdlog "log" + + metrics "github.com/armon/go-metrics" + pb "github.com/signalapp/svr2/proto" +) + +var ( + enclavePath = flag.String("enclave_path", "", "Path to binary holding the enclave") + econfigPath = flag.String("econfig_path", "", "Path to enclave configuration prototext file") + hconfigPath = flag.String("hconfig_path", "", "Path to host configuration yaml file") +) + +func main() { + flag.Parse() + + hconfig, err := config.Read(*hconfigPath) + if err != nil { + stdlog.Fatalf("could not read configuration: %v", err) + } + logger.Init(hconfig) + defer logger.Sync() + + // configure metrics + if hconfig.DatadogAgentHost != "" { + logger.Infof("initializing datadog at %v", hconfig.DatadogAgentHost) + sink, err := datadog.NewDogStatsdSink(hconfig.DatadogAgentHost, "") + if err != nil { + logger.Fatalf("error initializing statsd client: %v", err) + } + defer sink.Shutdown() + + // disable hostname tagging, this can be provided by the downstream sink + cfg := metrics.DefaultConfig("svr2") + cfg.EnableHostname = false + cfg.EnableHostnameLabel = false + + _, err = metrics.NewGlobal(cfg, sink) + if err != nil { + logger.Fatalf("error initializing metrics : %v", err) + } + } + authSecret, ok := os.LookupEnv("AUTH_SECRET") + if !ok { + logger.Fatalf("no auth secret env (AUTH_SECRET)") + } + authBytes, err := 
base64.StdEncoding.DecodeString(authSecret) + if err != nil { + logger.Fatalf("auth secret invalid base64: %v", err) + } + authenticator := auth.New(authBytes) + + var econfig pb.InitConfig + if configBytes, err := os.ReadFile(*econfigPath); err != nil { + logger.Fatalf("error reading config file %q: %v", *econfigPath, err) + } else if err = prototext.Unmarshal([]byte(os.ExpandEnv(string(configBytes))), &econfig); err != nil { + logger.Fatalf("error reading config (ASCII proto): %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + interrupts := make(chan os.Signal, 1) + signal.Notify(interrupts, os.Interrupt) + defer func() { + signal.Stop(interrupts) + cancel() + }() + + go func() { + select { + case <-interrupts: + logger.Infof("received interrupt, shutting down...") + cancel() + case <-ctx.Done(): + } + }() + + err = service.Start(ctx, &econfig, hconfig, *enclavePath, authenticator) + logger.Fatalw("Shutting down", "error", err) +} diff --git a/host/miniredis/miniredis.go b/host/miniredis/miniredis.go new file mode 100644 index 0000000..6a23b92 --- /dev/null +++ b/host/miniredis/miniredis.go @@ -0,0 +1,25 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Binary miniredis sets up a usable Redis port at --addr, good for simple local testing. 
+package main + +import ( + "flag" + "log" + + "github.com/alicebob/miniredis/v2" +) + +var ( + addr = flag.String("addr", "", "MiniRedis bind address") +) + +func main() { + flag.Parse() + r := miniredis.NewMiniRedis() + if err := r.StartAddr(*addr); err != nil { + log.Fatal(err) + } + select {} +} diff --git a/host/peer/client.go b/host/peer/client.go new file mode 100644 index 0000000..fb1778e --- /dev/null +++ b/host/peer/client.go @@ -0,0 +1,507 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package peer + +import ( + "bufio" + "context" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/armon/go-metrics" + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/peerid" + "github.com/signalapp/svr2/util" + "golang.org/x/sync/errgroup" + + pb "github.com/signalapp/svr2/proto" +) + +// PeerLookup provides a way to get a peer's hostname from its PeerID +type PeerLookup interface { + // Lookup takes a PeerID and attempts to find the associated hostname. If the peer is not + // found, the returned string and error will be nil + Lookup(context.Context, peerid.PeerID) (*string, error) +} + +// PeerClient can be used to send PeerMessages to remote peers +// +// PeerClient routes PeerMessages to dedicated goroutines (peerSenders) +// that handle the peer protocol handshake and append sequence numbers to every +// PeerMessage. +// +// Usually the peer client must store outbound messages until they are acked. If the TCP +// connection between client and server is terminated, the client must resend pending messages to the remote +// peer on reconnection. However, there are certain cases where a client will declare message bankruptcy, +// and drop these pending messages. 
This can happen when: +// - A caller is sending messages faster than the peerClient (or the remote peer server) can process them +// - We have been trying and failing to connect to a peer server for a substantial period of time +// +// In the first case, the client may try to keep the underlying TCP connection open, but drop pending messages. In the +// second case, the client will "abandon" the peer and not attempt to connect to the server until a reset is received. +// When messages are dropped, an error will be returned to the caller indicating that the enclave-to-enclave session +// must be re-established. +// +// ┌──────────────┐ ┌──────────────┐ +// │ │ │ require │ +// │ unknown peer │ │ reset │ +// │ │ │ │ +// └───────┬──────┘ └─┬───────────▲┘ +// │ │ │ +// any │ RST/SYN│ │ +// message │ │ │ +// │ │ │ +// ┌───────▼──────┐◄───────┘ │ +// │ │ │ +// │ buffered ├────────────────────┘ +// │ sending │ buffer full / +// └──────────────┘ server unresponsive +type PeerClient struct { + me peerid.PeerID // sender's peerID + cfg *config.PeerConfig // client configuration + peerLookup PeerLookup // fetches the remote endpoint associated with a PeerID + eg *errgroup.Group // indicates one of the child senders has closed or hit an unrecoverable error + ctx context.Context + + sendersMu sync.Mutex + senders map[peerid.PeerID]*peerSender // map of live peerSenders + abandonedPeers map[peerid.PeerID]bool // set of peers we've previously talked to but now have abandoned +} + +var ( + activeConnectionsGauge = []string{"peer", "client", "activeConnections"} + outboundQueueLengthGauge = []string{"peer", "client", "outboundQueueLength"} + connectAttemptCounter = []string{"peer", "client", "connectAttempt"} + abandonedPeerCounter = []string{"peer", "client", "abandon"} + resendCounter = []string{"peer", "client", "resend"} + sendCounter = []string{"peer", "client", "send"} + ackCounter = []string{"peer", "client", "ack"} + epochCounter = []string{"peer", "client", "epoch"} +) + 
+// NewPeerClient creates a PeerClient +func NewPeerClient( + ctx context.Context, + me peerid.PeerID, + peerLookup PeerLookup, + cfg *config.PeerConfig) *PeerClient { + + eg, ctx := errgroup.WithContext(ctx) + + eg.Go(func() error { + // finishes if the caller cancels or any child experiences an error + <-ctx.Done() + return ctx.Err() + }) + + return &PeerClient{ + me: me, + cfg: cfg, + peerLookup: peerLookup, + eg: eg, + ctx: ctx, + senders: make(map[peerid.PeerID]*peerSender), + abandonedPeers: make(map[peerid.PeerID]bool), + } +} + +// Run runs until the PeerClient experiences a terminal error or is shutdown +func (p *PeerClient) Run() error { + return p.eg.Wait() +} + +var ErrResetConnection = errors.New("connection must be reset") +var errAbandonPeer = errors.New("peer connect timed out") + +// Send a message to a peer. +// +// ErrResetConnection may be returned if we cannot deliver messages to this peer. In this case +// messages will be dropped and the caller must reset their peer session. +func (p *PeerClient) Send(msg *pb.PeerMessage) error { + peerID, err := peerid.Make(msg.PeerId) + if err != nil { + return err + } + sender, err := p.getOrCreateSender(msg, peerID) + if err != nil { + return err + } + return sender.queueMessage(msg) +} + +// getOrCreateSender returns the existing peerSender for the peerID, or creates a new one if it doesn't exist +func (p *PeerClient) getOrCreateSender(msg *pb.PeerMessage, peerID peerid.PeerID) (*peerSender, error) { + p.sendersMu.Lock() + defer p.sendersMu.Unlock() + sender, ok := p.senders[peerID] + if ok { + return sender, nil + } + + // check if we've had a sender for this peer in the past + if p.abandonedPeers[peerID] { + // This is a peer we previously decided to abandon. 
+ if !isEstablishing(msg) { + // We can resume talking to it, but any new communication + // must first establish a new enclave connection + logger.Warnw("attempting to send non-establishing message to previously abandoned peer", + "peerID", peerID) + return nil, ErrResetConnection + } + + // otherwise, we can create a new connection for it + delete(p.abandonedPeers, peerID) + logger.Infow("attempting to reconnect to previously abandoned peer", "peerID", peerID) + } else { + logger.Infow("creating new peerSender on first message to peer", "peerID", peerID) + } + + sender = newPeerSender(p.me, peerID, p.peerLookup, p.cfg) + p.senders[peerID] = sender + metrics.SetGauge(activeConnectionsGauge, float32(len(p.senders))) + p.eg.Go(func() error { + err := sender.run(p.ctx) + + // Remove the sender from the sender's map. + // Note: There's a harmless race here. If a Send caller has already retrieved + // their sender and is in the midst of calling queueMessage when the sender + // exits, the message will never be processed. This is fine, because we want + // to drop old messages anyway. Because queueMessage never blocks, there's no + // deadlock concern either. + p.sendersMu.Lock() + defer p.sendersMu.Unlock() + + // a subsequent send will need to create a new sender + delete(p.senders, peerID) + + if errors.Is(err, errAbandonPeer) { + // remember if we gave up on this peer, so we know to reset our connection + // if we communicate with them again.
+ p.abandonedPeers[peerID] = true + + // not a fatal error + return nil + } + + return err + }) + + return sender, nil + +} + +// isEstablishing reports whether msg is a Syn, Synack or Rst message. +func isEstablishing(msg *pb.PeerMessage) bool { + switch msg.Inner.(type) { + case *pb.PeerMessage_Syn, *pb.PeerMessage_Synack, *pb.PeerMessage_Rst: + return true + default: + return false + } +} + +// peerSender handles PeerMessages for one particular peer +// +// peerSenders try to re-connect to the remote peer on errors +// and only stop running on unrecoverable errors +type peerSender struct { + cfg *config.PeerConfig // client configuration + me peerid.PeerID // the sending local peer + remote peerid.PeerID // the targeted remote peer + peerLookup PeerLookup // name resolution for peers + pending []*pb.PeerConnectionData // requests that might be resent + lastAck sequenceNumber // lastAck + 1 should always be pending[0]'s sequence number + tx atomic.Pointer[chan *pb.PeerMessage] // on epoch bumps, old messages can be discarded so the send channel is atomically replaced + labels []metrics.Label // metric labels to attach to metrics from this sender +} + +// newPeerSender builds a peerSender for remote with an empty pending list and a +// send channel buffered to cfg.BufferSize. +func newPeerSender( + me, remote peerid.PeerID, + peerLookup PeerLookup, + cfg *config.PeerConfig) *peerSender { + s := &peerSender{ + cfg: cfg, + me: me, + remote: remote, + peerLookup: peerLookup, + pending: nil, + tx: atomic.Pointer[chan *pb.PeerMessage]{}, + labels: []metrics.Label{{ + Name: "peerID", + Value: remote.String(), + }}, + } + c := make(chan *pb.PeerMessage, cfg.BufferSize) + s.tx.Store(&c) + return s +} + +// run resolves the remote peer's address and then repeatedly dials it, handing each +// connection to handleConnection. It returns errAbandonPeer if the address lookup fails +// or if more than cfg.AbandonDuration has passed since lastConnect, and ctx.Err() once +// ctx is cancelled. +func (p *peerSender) run(ctx context.Context) error { + + peerAddr, err := p.lookupPeerAddr(ctx) + if err != nil { + // this peer has gone away + // (structured key/value logging: the *w loggers do not expand format verbs) + logger.Warnw("could not lookup peer, giving up", "peerID", p.remote, "err", err) + return errAbandonPeer + } + + lastConnect := time.Now() + sleepTime := time.Duration(0) + for { + if time.Since(lastConnect) > p.cfg.AbandonDuration { + // we've been trying to connect to this peer for long enough, give up +
metrics.IncrCounterWithLabels(abandonedPeerCounter, 1, p.labels) + return errAbandonPeer + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(sleepTime): + } + + // attempt to connect to peer + metrics.IncrCounterWithLabels(connectAttemptCounter, 1, p.labels) + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", peerAddr) + if err != nil { + logger.Infow("Failed to connect to peer", + "peerID", p.remote, + "peer", peerAddr, + "err", err) + sleepTime = util.Clamp(sleepTime*2, p.cfg.MinSleepDuration, p.cfg.MaxSleepDuration) + continue + } + + // once connected, actually kickoff the sender + done := make(chan error, 1) + start := time.Now() + go func() { done <- p.handleConnection(ctx, conn) }() + select { + case err := <-done: + // error case, retry + duration := time.Since(start) + logger.Infow("Peer connection terminated", + "peerID", p.remote, + "peer", peerAddr, + "err", err, + "connected_duration", duration) + + // we don't want to hammer the peer if it is failing right away, + // but we don't need to sleep max time if we've been connected for + // a while. subtract out the amount of time we've been running for + sleepTime = util.Clamp(sleepTime*2-duration, p.cfg.MinSleepDuration, p.cfg.MaxSleepDuration) + + var handshakeErr *errFailedHandshake + if !errors.As(err, &handshakeErr) { + // If we managed to actually get to send some data to the peer + // restart the timer on detecting dead peers. 
+ lastConnect = time.Now() + } + continue + case <-ctx.Done(): + // externally closed, close the connection and wait for handle to finish before exiting + conn.Close() + <-done + return ctx.Err() + } + } +} + +func (p *peerSender) lookupPeerAddr(ctx context.Context) (string, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + return util.RetrySupplierWithBackoff(ctx, func() (string, error) { + // attempt to get the peer's host name + addr, err := p.peerLookup.Lookup(ctx, p.remote) + if err != nil { + logger.Warnw("Failed to lookup peer", "peerID", p.remote, "err", err) + return "", err + } + if addr == nil { + // A peer that has been added to the raft group is not known to the host. + // It was probably wiped out and it's redis entry was removed, mark it as + // abandoned. + logger.Warnw("Remote peer does not exist", "peerID", p.remote) + cancel() + return "", errors.New("remote peer does not exist") + } + return *addr, nil + }, p.cfg.MinSleepDuration, p.cfg.MaxSleepDuration) +} + +// queue a message to be sent by the run loop +func (p *peerSender) queueMessage(msg *pb.PeerMessage) error { + var c chan *pb.PeerMessage + if isEstablishing(msg) { + // these messages indicate we don't care about previous messages, we can replace our + // send channel with a fresh one + c = make(chan *pb.PeerMessage, p.cfg.BufferSize) + close(*p.tx.Swap(&c)) + } else { + c = *p.tx.Load() + } + + metrics.SetGaugeWithLabels(outboundQueueLengthGauge, float32(len(c)), p.labels) + + select { + case c <- msg: + return nil + default: + return ErrResetConnection + } +} + +// processAck drops acknowledged pending messages +func (p *peerSender) processAck(ack sequenceNumber) error { + if ack.cmp(p.lastAck) < 0 { + // this peer is buggy + return fmt.Errorf("remote peer illegal ack %v, must be at least %v", ack, p.lastAck) + } + + // drop any pending requests that have already been acked + for len(p.pending) > 0 { + seqno, err := makeSeqno(p.pending[0].Seqno) + if err != nil { + 
return err + } + if seqno.cmp(ack) > 0 { + break + } + p.pending = p.pending[1:] + } + p.lastAck = ack + return nil +} + +type ackResult struct { + ack sequenceNumber + err error +} + +// ackLoop forwards acks read from a connection until it is cancelled or hits an error +func (p *peerSender) ackLoop(ctx context.Context, r *bufio.Reader, ackOut chan ackResult) { + for { + ack, err := readAck(r) + select { + case ackOut <- ackResult{ack, err}: + if err != nil { + return + } + case <-ctx.Done(): + return + } + } +} + +type errFailedHandshake struct{ reason error } + +func (e *errFailedHandshake) Error() string { return fmt.Sprintf("failed handshake: %v", e.reason) } + +func (p *peerSender) handleConnection(ctx context.Context, conn net.Conn) error { + defer conn.Close() + + logger := logger.With("peer", conn.RemoteAddr(), "peerID", p.remote) + + // handshake + logger.Infow("writing hello") + if err := writeHello(conn, p.me, p.remote); err != nil { + return &errFailedHandshake{err} + } + reader := bufio.NewReader(conn) + lastAck, err := readHelloAck(reader) + if err != nil { + return &errFailedHandshake{err} + } + + // find out which messages haven't been received by the remote peer + if err := p.processAck(lastAck); err != nil { + return &errFailedHandshake{err} + } + logger.Infow("resending pending messages on connect", + "last_ack", lastAck, + "pending", len(p.pending)) + + // resend any unacked messages + currentSeqno := p.lastAck.next() + + for _, msg := range p.pending { + metrics.IncrCounterWithLabels(resendCounter, 1, p.labels) + err := writeFramed(conn, &pb.PeerConnectionMessage{Inner: &pb.PeerConnectionMessage_Data{Data: msg}}) + if err != nil { + return fmt.Errorf("resend pending: %w", err) + } + currentSeqno = currentSeqno.next() + } + + // goroutine to read ack responses and send up acks + ackCtx, ackCancel := context.WithCancel(ctx) + ackChan := make(chan ackResult) + go p.ackLoop(ackCtx, reader, ackChan) + // once we're done, cancel the ack reader + 
defer ackCancel() + + for { + + msgChan := *p.tx.Load() + + // process new sends / listen for acks + select { + case msg := <-msgChan: + metrics.SetGaugeWithLabels(outboundQueueLengthGauge, float32(len(msgChan)), p.labels) + + // Check if channel was closed; if so, ignore. + if msg == nil { + continue + } + metrics.IncrCounterWithLabels(sendCounter, 1, p.labels) + + switch msg.Inner.(type) { + case *pb.PeerMessage_Syn, *pb.PeerMessage_Synack: + // bump our epoch, reset num to 0 + currentSeqno = currentSeqno.nextEpoch() + metrics.IncrCounterWithLabels(epochCounter, 1, p.labels) + + // drop our pending queue (should be for previous epoch) + // Note: if this is a Syn, it's possible we have a pending Rst to the remote + // peer, and it could get dropped. This is fine though because we already know + // a Syn is going (we're sending it right now) + p.pending = nil + } + + logger.Debugw("got peermessage from enclave to send to peer", + "seqno", currentSeqno, + "type", fmt.Sprintf("%T", msg.Inner)) + + pcd := pb.PeerConnectionData{ + Msg: msg, + Seqno: currentSeqno.proto(), + } + currentSeqno = currentSeqno.next() + + // send the message and save it for later resending + p.pending = append(p.pending, &pcd) + err := writeFramed(conn, &pb.PeerConnectionMessage{Inner: &pb.PeerConnectionMessage_Data{Data: &pcd}}) + if err != nil { + return fmt.Errorf("send data: %w", err) + } + case ack := <-ackChan: + if ack.err != nil { + return ack.err + } + metrics.IncrCounterWithLabels(ackCounter, 1, p.labels) + logger.Debugw("got ack from peer", "seqno", ack.ack) + // dispose of everything we know has been acked by the peer + if err := p.processAck(ack.ack); err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} diff --git a/host/peer/client_test.go b/host/peer/client_test.go new file mode 100644 index 0000000..78c7a15 --- /dev/null +++ b/host/peer/client_test.go @@ -0,0 +1,664 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: 
AGPL-3.0-only + +package peer + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "log" + "net" + "reflect" + "testing" + "time" + + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/peerid" + "github.com/signalapp/svr2/servicetest" + + pb "github.com/signalapp/svr2/proto" +) + +// mapPeerDB is a static mapping of peers initialized at startup +type mapPeerDB struct { + m map[peerid.PeerID]string +} + +// Implements PeerLookup +func (m *mapPeerDB) Lookup(ctx context.Context, peer peerid.PeerID) (*string, error) { + val, ok := m.m[peer] + if !ok { + return nil, nil + } + return &val, nil +} + +type clientFixture struct { + data []byte + r *bufio.Reader + server net.Conn + client net.Conn + done chan error +} + +// start initializes a server/client connection, provides it to the sender, +// and allows subsequent calls to receive messages from the server connection +func (c *clientFixture) start(t *testing.T, sender *peerSender) { + c.startFrom(t, sender, sequenceNumber{}) +} + +// startFrom is like start, but completes the handshake by replying to the +// client's hello with helloAck instead of the zero sequence number. +func (c *clientFixture) startFrom(t *testing.T, sender *peerSender, helloAck sequenceNumber) { + c.server, c.client = net.Pipe() + c.r = bufio.NewReader(c.server) + + c.done = make(chan error) + go func() { c.done <- sender.handleConnection(context.Background(), c.client) }() + + if from, to := c.readHello(t); from != clientID { + t.Errorf("client hello fromPeerID %v, want %v", from, clientID) + } else if to != serverID { + t.Errorf("client hello toPeerID %v, want %v", to, serverID) + } + + // reply on this fixture's own connection rather than the package-level fixture + c.writeHelloAck(t, helloAck) +} + +// close closes the server side of the pipe and returns the error that +// handleConnection finished with. +func (c *clientFixture) close() error { + c.server.Close() + return <-c.done +} + +// readHello reads the client's hello from the server side of the pipe and +// returns the sender and recipient peer IDs it carried, in that order. +func (c *clientFixture) readHello(t *testing.T) (from, to peerid.PeerID) { + from, to, err := readHello(c.r) + if err != nil { + t.Fatalf("fReadHello: %v", err) + } + return from, to +} + +// writeHelloAck writes a hello ack carrying seqno on the server side of the pipe. +func (c *clientFixture) writeHelloAck(t *testing.T, seqno sequenceNumber) { + if err := writeHelloAck(c.server, seqno); err != nil { + t.Error(err) + } +} + +func (c
*clientFixture) writeAck(t *testing.T, seqno sequenceNumber) { + if err := writeAck(c.server, seqno); err != nil { + t.Error(err) + } +} + +func (c *clientFixture) sendSyn(t *testing.T, s *peerSender) { + c.queueMessage(t, s, c.peerSynMessage()) +} + +func (c *clientFixture) sendData(t *testing.T, s *peerSender) { + c.queueMessage(t, s, c.peerDataMessage()) +} + +func (c *clientFixture) queueMessage(t *testing.T, s *peerSender, msg *pb.PeerMessage) { + if err := s.queueMessage(msg); err != nil { + t.Error(err) + } +} + +func (c *clientFixture) expectMessage(t *testing.T, expectedSeqno sequenceNumber) { + m, err := readFramed(c.r) + if err != nil { + t.Error(err) + return + } + actual, err := makeSeqno(m.GetData().Seqno) + if err != nil { + t.Error(err) + return + } + if actual != expectedSeqno { + t.Errorf("message seqno=%v, want %v", actual, expectedSeqno) + return + } + var bs []byte + switch v := m.GetData().GetMsg().Inner.(type) { + case *pb.PeerMessage_Syn: + bs = v.Syn + case *pb.PeerMessage_Data: + bs = v.Data + default: + t.Errorf("unexpected message") + return + } + if !bytes.Equal(bs, c.data) { + t.Errorf("message data=%v, want %v", bs, c.data) + } +} + +func (c *clientFixture) sequenced(seqno sequenceNumber) *pb.PeerConnectionData { + return &pb.PeerConnectionData{ + Msg: &pb.PeerMessage{ + Inner: &pb.PeerMessage_Syn{Syn: c.data}, + }, + Seqno: seqno.proto(), + } +} + +func (c *clientFixture) peerSynMessage() *pb.PeerMessage { + return &pb.PeerMessage{ + Inner: &pb.PeerMessage_Syn{Syn: c.data}, + PeerId: serverID[:], + } +} + +func (c *clientFixture) peerDataMessage() *pb.PeerMessage { + return &pb.PeerMessage{ + Inner: &pb.PeerMessage_Data{Data: c.data}, + PeerId: serverID[:], + } +} + +type SenderOption func(*peerSender) + +func withBuffer(bufSize int) SenderOption { + return func(sender *peerSender) { + sender.cfg.BufferSize = bufSize + c := make(chan *pb.PeerMessage, bufSize) + sender.tx.Store(&c) + } +} + +func withLookup(peerID peerid.PeerID, addr 
string) SenderOption { + return func(sender *peerSender) { + sender.peerLookup = &mapPeerDB{map[peerid.PeerID]string{peerID: addr}} + } +} + +func createSender(options ...SenderOption) *peerSender { + cfg := config.Default() + sender := newPeerSender(clientID, serverID, &mapPeerDB{}, &cfg.Peer) + for _, opt := range options { + opt(sender) + } + return sender +} + +var ( + clientID = peer(0) + serverID = peer(1) + f = clientFixture{data: []byte("data")} +) + +func (*clientFixture) seq(s uint64) sequenceNumber { + return sequenceNumber{ + seq: s, + epoch: 0, + } +} + +func TestProcessAck(t *testing.T) { + + orig := []*pb.PeerConnectionData{ + f.sequenced(f.seq(1)), + f.sequenced(f.seq(2)), + f.sequenced(f.seq(3)), + } + + tests := []struct { + ack sequenceNumber + expected []*pb.PeerConnectionData + }{ + {f.seq(0), orig}, + {f.seq(1), orig[1:]}, + {f.seq(2), orig[2:]}, + {f.seq(3), []*pb.PeerConnectionData{}}, + // acks past max sent, should drop everything + {f.seq(4), []*pb.PeerConnectionData{}}, + {f.seq(100), []*pb.PeerConnectionData{}}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("ack=%v", tt.ack), func(t *testing.T) { + sender := peerSender{pending: orig} + err := sender.processAck(tt.ack) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(sender.pending, tt.expected) { + t.Errorf("processAck(%v)=%v, want %v", tt.ack, sender.pending, tt.expected) + } + }) + } +} + +func TestHandleConnection(t *testing.T) { + sender := createSender() + + f.start(t, sender) + f.sendSyn(t, sender) + f.expectMessage(t, sequenceNumber{1, 0}) + if sender.lastAck != (sequenceNumber{}) { + t.Errorf("sender lastAck=%v, want %v", sender.lastAck, 0) + } + + f.writeAck(t, sequenceNumber{1, 0}) + f.close() + + if sender.lastAck != (sequenceNumber{1, 0}) { + t.Errorf("sender lastAck=%v, want %v", sender.lastAck, 1) + } +} + +func TestResendPending(t *testing.T) { + for _, i := range []uint64{3, 4, 5, 8, 9} { + t.Run(fmt.Sprintf("helloAck%v", i), func(t *testing.T) { + 
f.resendPendingHelper(t, f.seq(i)) + }) + } +} + +func (c *clientFixture) resendPendingHelper(t *testing.T, helloAck sequenceNumber) { + sender := createSender() + + f.start(t, sender) + + for i := 1; i < 10; i++ { + f.sendData(t, sender) + f.expectMessage(t, f.seq(uint64(i))) + } + + // lie and only ack id 3 + f.writeAck(t, f.seq(3)) + t.Logf("done reason: %v", f.close()) + + // restart, should resend from helloAck up to 10 + f.startFrom(t, sender, helloAck) + + for i := helloAck.seq + 1; i < 10; i++ { + f.expectMessage(t, f.seq(i)) + } + + // should be able to send a new message at seqno=10 + f.sendData(t, sender) + f.expectMessage(t, f.seq(10)) + + t.Logf("done reason: %v", f.close()) +} + +func TestDropOldEpoch(t *testing.T) { + sender := createSender() + + f.start(t, sender) + + // these should not be resent on disconnect + f.sendSyn(t, sender) + f.expectMessage(t, sequenceNumber{epoch: 1, seq: 0}) + f.sendData(t, sender) + f.expectMessage(t, sequenceNumber{epoch: 1, seq: 1}) + + // SYN should start a new epoch + f.sendSyn(t, sender) + f.expectMessage(t, sequenceNumber{epoch: 2, seq: 0}) + f.sendData(t, sender) + f.expectMessage(t, sequenceNumber{epoch: 2, seq: 1}) + + // should resend the Syn and first message + t.Logf("done reason: %v", f.close()) + f.startFrom(t, sender, sequenceNumber{0, 1}) + f.expectMessage(t, sequenceNumber{epoch: 2, seq: 0}) + f.expectMessage(t, sequenceNumber{epoch: 2, seq: 1}) + + t.Logf("done reason: %v", f.close()) + +} + +func TestManySyncs(t *testing.T) { + sender := createSender() + + f.start(t, sender) + + // these should not be resent on disconnect + for i := uint32(1); i <= 5; i++ { + f.sendSyn(t, sender) + f.expectMessage(t, sequenceNumber{epoch: i, seq: 0}) + } + + // should only send the latest epoch + t.Logf("done reason: %v", f.close()) + f.startFrom(t, sender, sequenceNumber{0, 1}) + f.expectMessage(t, sequenceNumber{epoch: 5, seq: 0}) + t.Logf("done reason: %v", f.close()) + +} + +func TestBufferLimit(t *testing.T) { + 
sender := createSender(withBuffer(10)) + + f.start(t, sender) + + // first 10 sends should be fine + for i := 0; i < 10; i++ { + f.sendData(t, sender) + } + + // 11th send we're out of space + if err := sender.queueMessage(f.peerDataMessage()); err != ErrResetConnection { + t.Fatalf("got %v, want %v", err, ErrResetConnection) + } + + // send a Syn, which should work and free up our buffer for 10 more sends + f.sendSyn(t, sender) + for i := 0; i < 9; i++ { + f.sendData(t, sender) + } +} + +func TestAbandonPeer(t *testing.T) { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + lookup := &mapPeerDB{map[peerid.PeerID]string{serverID: ln.Addr().String()}} + + // no longer listening on address + ln.Close() + + peerClient := NewPeerClient(context.Background(), clientID, lookup, &config.PeerConfig{ + AbandonDuration: time.Millisecond * 10, + BufferSize: 1000, + }) + + // eventually these should fail + for i := 0; i < 10; i++ { + err = peerClient.Send(&pb.PeerMessage{ + PeerId: serverID[:], + Inner: &pb.PeerMessage_Data{}, + }) + if err != nil { + break + } + time.Sleep(time.Millisecond * 5) + } + + if !errors.Is(err, ErrResetConnection) { + t.Fatalf("got %v, want %v", err, ErrResetConnection) + } + + if _, ok := peerClient.abandonedPeers[serverID]; !ok { + t.Fatal("peer should be marked abandoned") + } + +} + +func TestAbandonZombiePeer(t *testing.T) { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + lookup := &mapPeerDB{map[peerid.PeerID]string{serverID: ln.Addr().String()}} + + peerClient := NewPeerClient(context.Background(), clientID, lookup, &config.PeerConfig{ + AbandonDuration: time.Millisecond * 10, + BufferSize: 1000, + }) + + // server accepts a connection and reads a hello, but then terminates + serverDone := make(chan error, 1) + go func() { + for { + c, err := ln.Accept() + if err != nil { + close(serverDone) + return + } + r := bufio.NewReader(c) + if _, _, err = readHello(r); err != nil { + serverDone <- 
err + return + } + c.Close() + } + }() + + // eventually these should fail + for i := 0; i < 10; i++ { + err = peerClient.Send(&pb.PeerMessage{ + PeerId: serverID[:], + Inner: &pb.PeerMessage_Data{}, + }) + if err != nil { + break + } + time.Sleep(time.Millisecond * 5) + } + + if !errors.Is(err, ErrResetConnection) { + t.Fatalf("got %v, want %v", err, ErrResetConnection) + } + + if _, ok := peerClient.abandonedPeers[serverID]; !ok { + t.Fatal("peer should be marked abandoned") + } + ln.Close() + if err := <-serverDone; err != nil { + t.Fatal(err) + } + +} + +func TestProdigalPeer(t *testing.T) { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + lookup := &mapPeerDB{map[peerid.PeerID]string{serverID: ln.Addr().String()}} + peerClient := NewPeerClient(context.Background(), clientID, lookup, &config.PeerConfig{ + AbandonDuration: time.Minute, + BufferSize: 1, + }) + + peerClient.abandonedPeers[serverID] = true + + dataMsg := &pb.PeerMessage{ + PeerId: serverID[:], + Inner: &pb.PeerMessage_Data{}, + } + + // a data message should be outright rejected + if err := peerClient.Send(dataMsg); !errors.Is(err, ErrResetConnection) { + t.Fatalf("Send(data message) = %v, want %v", err, ErrResetConnection) + } + + serverDone := make(chan error, 1) + server := errorServer{seqno: sequenceNumber{1, 9}, id: serverID, ln: ln} + go func() { serverDone <- server.run(fsOk, 1) }() + + // should allow an RST message + rstMsg := &pb.PeerMessage{ + PeerId: serverID[:], + Inner: &pb.PeerMessage_Rst{Rst: true}, + } + if err = peerClient.Send(rstMsg); err != nil { + t.Fatalf("Send(rst) = %v, want %v", err, nil) + } + + if err := <-serverDone; err != nil { + t.Fatalf("server failed with: %v", err) + } + + if expected := (sequenceNumber{1, 10}); server.seqno != expected { + t.Fatalf("server recieved seqno=%v, want %v", server.seqno, expected) + } +} + +func TestPeerLookupFails(t *testing.T) { + // peer won't be found + lookup := &mapPeerDB{map[peerid.PeerID]string{}} + + 
dataMsg := &pb.PeerMessage{ + PeerId: serverID[:], + Inner: &pb.PeerMessage_Data{}, + } + + peerClient := NewPeerClient(context.Background(), clientID, lookup, &config.PeerConfig{ + AbandonDuration: time.Minute, + BufferSize: 1000, + }) + + if err := peerClient.Send(dataMsg); err != nil { + t.Fatalf("first send failed: %v", err) + } + _, err := servicetest.RetryFun(time.Second*5, func() (interface{}, error) { + err := peerClient.Send(dataMsg) + if errors.Is(err, errAbandonPeer) { + return nil, errors.New("should eventually fail") + } + return nil, nil + }) + if err != nil { + t.Fatal("send never failed") + } + +} + +func TestRetryErrors(t *testing.T) { + + tests := []struct { + name string + // where to error on first server run + fs failureStage + // number of messages to read on first server run + expectedFirstRun int + // number of messages to read on second server run + expectedSecondRun int + }{ + {"hello", fsHello, 0, 2}, + {"helloAck", fsHelloAck, 0, 2}, + {"receive", fsReceive, 0, 2}, + {"ack", fsAck, 2, 0}, + {"ok", fsOk, 2, 0}, + {"ok split", fsOk, 1, 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + errorServer := errorServer{id: serverID, ln: ln} + serverDone := make(chan error) + go func() { + // run error server, erroring when configured + if err := errorServer.run(tt.fs, tt.expectedFirstRun); err != nil { + log.Println(err) + t.Error(err) + } + + t.Logf("finished first run") + + // restart errorServer but always succeed + err := errorServer.run(fsOk, tt.expectedSecondRun) + ln.Close() + serverDone <- err + }() + + // send two messages to the error server + sender := createSender( + withLookup(serverID, ln.Addr().String()), + withBuffer(2)) + + senderDone := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) + go func() { senderDone <- sender.run(ctx) }() + f.sendSyn(t, sender) + f.sendData(t, sender) + + // wait for server to 
read our messages + select { + case err := <-senderDone: + t.Fatal(err) + case err := <-serverDone: + if err != nil { + t.Error(err) + } + } + cancel() + t.Logf("finished %v", <-senderDone) + }) + } + +} + +type failureStage = int + +const ( + fsHello failureStage = iota + fsHelloAck + fsReceive + fsAck + fsOk +) + +type errorServer struct { + seqno sequenceNumber + id peerid.PeerID + ln net.Listener +} + +func (e *errorServer) run(failure failureStage, expectedMessages int) error { + c, err := e.ln.Accept() + if err != nil { + return err + } + log.Printf("got connection local=%v, remote=%v", c.LocalAddr().String(), c.RemoteAddr().String()) + + defer c.Close() + + r := bufio.NewReader(c) + + if failure == fsHello { + return nil + } + + if _, _, err = readHello(r); err != nil { + return err + } + + if failure == fsHelloAck { + return nil + } + + if err := writeHelloAck(c, e.seqno); err != nil { + return err + } + + if failure == fsReceive { + return nil + } + + // receive expectedMessages messages + for i := 0; i < expectedMessages; i++ { + msg, err := readFramed(r) + if err != nil { + return err + } + e.seqno, err = makeSeqno(msg.GetData().Seqno) + if err != nil { + return err + } + + } + + if failure == fsAck { + return nil + } + + if err := writeAck(c, e.seqno); err != nil { + return err + } + return nil +} + +func peer(i byte) peerid.PeerID { + return [32]byte{i} +} diff --git a/host/peer/peer.go b/host/peer/peer.go new file mode 100644 index 0000000..2d1a8ca --- /dev/null +++ b/host/peer/peer.go @@ -0,0 +1,18 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Package peer implements the host to host network protocol for sending messages between enclaves +package peer + +import ( + pb "github.com/signalapp/svr2/proto" +) + +var ( + maxMessageLength uint64 = 1024 * 1024 * 128 +) + +type EnclaveSender interface { + // Send sends a message to the enclave and potentially waits for a reply + Send(p *pb.UntrustedMessage) 
(*pb.EnclaveMessage, error) +} diff --git a/host/peer/peerdb/peerdb.go b/host/peer/peerdb/peerdb.go new file mode 100644 index 0000000..39fa49d --- /dev/null +++ b/host/peer/peerdb/peerdb.go @@ -0,0 +1,217 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package peerdb + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "sort" + "strings" + "sync" + "time" + + "github.com/go-redis/redis/v8" + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/peerid" + "github.com/signalapp/svr2/util" + "google.golang.org/protobuf/encoding/protojson" + + pb "github.com/signalapp/svr2/proto" +) + +// PeerDB associates the enclave level PeerID with the external hostname of the peer +type PeerDB struct { + cfg config.RedisConfig + rdb *redis.ClusterClient + peersKeyPrefix string // key prefix for peer values. You can't set TTLs on individual entries in a hash, so we set all peers as individual values. + createRaftKeyName string // key used for to exclusivity for raft group creation + clock util.Clock +} + +// New creates a PeerDB +func New(cfg config.RedisConfig) *PeerDB { + rdb := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: cfg.Addrs, + Password: cfg.Password, + }) + return &PeerDB{ + cfg: cfg, + rdb: rdb, + peersKeyPrefix: fmt.Sprintf("%s::%s::", cfg.Name, "peers"), + createRaftKeyName: fmt.Sprintf("%s::%s", cfg.Name, "create"), + clock: util.RealClock, + } +} + +func (p *PeerDB) peerKey(peer peerid.PeerID) string { + return p.peersKeyPrefix + hex.EncodeToString(peer[:]) +} + +func (p *PeerDB) peerFromKey(key string) (peerid.PeerID, error) { + if !strings.HasPrefix(key, p.peersKeyPrefix) { + return peerid.PeerID{}, fmt.Errorf("key does not have peer key prefix") + } else if bytes, err := hex.DecodeString(key[len(p.peersKeyPrefix):]); err != nil { + return peerid.PeerID{}, fmt.Errorf("hex-decoding key: %v", err) + } else { + return peerid.Make(bytes) + } +} + +// Lookup 
fetches the address stored for a peer; it returns (nil, nil) when the peer has no entry +func (p *PeerDB) Lookup(ctx context.Context, peer peerid.PeerID) (*string, error) { + peerEntry := pb.PeerEntry{} + if bs, err := p.rdb.Get(ctx, p.peerKey(peer)).Bytes(); err == redis.Nil { + return nil, nil // a missing entry is reported as "not found", not as an error + } else if err != nil { + return nil, err // surface redis failures instead of unmarshalling an empty reply + } else if err := protojson.Unmarshal(bs, &peerEntry); err != nil { + return nil, err + } + return &peerEntry.Addr, nil +} + +// Close should be called when the PeerDB is no longer used +func (p *PeerDB) Close() error { + return p.rdb.Close() +} + +// Insert adds a peer entry keyed by peerID to the PeerDB +// +// This method retries until the insert succeeds or the provided context is cancelled +func (p *PeerDB) Insert(ctx context.Context, me peerid.PeerID, addr string, ttl time.Duration) error { + return p.insert(ctx, me, addr, false, ttl) +} + +// JoinedRaft updates a peer entry to indicate that it has joined the raft cluster +// +// This method retries until the insert succeeds or the provided context is cancelled +func (p *PeerDB) JoinedRaft(ctx context.Context, me peerid.PeerID, addr string, ttl time.Duration) error { + return p.insert(ctx, me, addr, true, ttl) +} + +func (p *PeerDB) insert(ctx context.Context, me peerid.PeerID, addr string, isRaftMember bool, ttl time.Duration) error { + return util.RetryWithBackoff(ctx, func() error { + logger.Debugw("Attempting to add self to peerdb", "addr", addr, "raftmember", isRaftMember) + currentTime := p.clock.Now().Unix() + m := pb.PeerEntry{Addr: addr, LastUpdateTs: currentTime, RaftMember: isRaftMember} + if isRaftMember { + m.JoinTs = currentTime + } + bs, err := protojson.Marshal(&m) + if err != nil { + return err + } + if err := p.rdb.Set(ctx, p.peerKey(me), bs, ttl).Err(); err != nil { + logger.Errorw("rdb.Set for PeerDB insert error", "err", err) + return err + } + return nil + }, p.cfg.MinSleepDuration, p.cfg.MaxSleepDuration) + +} + +// FindRaftMember returns a member of an existing raft group that may be used to join raft +// If no such eligible 
members exist, this method may return `me`, indicating that it is safe +// to create a raft group instead of joining one. +func (p *PeerDB) FindRaftMember(ctx context.Context, me peerid.PeerID, localPeerAddr string) (peerid.PeerID, error) { + // retry until we find an eligible peer or we acquire the exclusive creation lock + return util.RetrySupplierWithBackoff(ctx, func() (peerid.PeerID, error) { + peers, err := p.list(ctx) + if err != nil { + logger.Infow("failed to fetch raft members", "err", err) + return peerid.PeerID{}, err + } + + var peerIDs []peerid.PeerID + for k, v := range peers { + if k == me || v.Addr == localPeerAddr { + continue + } + if !v.RaftMember { + continue + } + peerIDs = append(peerIDs, k) + } + + if len(peerIDs) == 0 { + logger.Infow("no available raft members, attempting to get creation lock") + if err := p.acquireCreationLock(ctx, me); err != nil { + // someone else probably got the lock, so our next attempt may go better + logger.Infow("failed to get creation lock", "err", err) + return peerid.PeerID{}, errors.New("no peers available and could not get creation lock") + } + return me, nil + } + // sort so the most recently joined member is first + sort.Slice(peerIDs, func(i int, j int) bool { + return peers[peerIDs[j]].JoinTs < peers[peerIDs[i]].JoinTs + }) + logger.Infow("found joinable raft peer", "peerID", peerIDs[0]) + return peerIDs[0], nil + + }, p.cfg.MinSleepDuration, p.cfg.MaxSleepDuration) +} + +// acquireCreationLock attempts to acquire an exclusive lock to create a raft group +// on success, this node may create a new raft group +// on error, this node should re-attempt to join from a peer +func (p *PeerDB) acquireCreationLock(ctx context.Context, me peerid.PeerID) error { + got, err := p.rdb.SetNX(ctx, p.createRaftKeyName, me[:], 0).Result() + if err != nil { + return err + } + if !got { + return errors.New("failure to get exclusive creation lock") + } + return nil +} + +// list fetches all the peers in the database +func (p 
*PeerDB) list(ctx context.Context) (map[peerid.PeerID]*pb.PeerEntry, error) { + var mu sync.Mutex + var shardResults [][]string + + err := p.rdb.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error { + mu.Lock() + defer mu.Unlock() + keys, err := shard.Keys(ctx, p.peersKeyPrefix+"*").Result() + if err != nil { + return err + } + logger.Debugf("Retrieved %v peers from peerdb shard", len(keys)) + shardResults = append(shardResults, keys) + return nil + }) + + if err != nil { + return nil, err + } + + ret := make(map[peerid.PeerID]*pb.PeerEntry) + for _, keys := range shardResults { + for _, key := range keys { + peerID, err := p.peerFromKey(key) + if err != nil { + return nil, fmt.Errorf("invalid peer key: %v", key) + } + v, err := p.rdb.Get(ctx, key).Result() + if err == redis.Nil { + // Key expired since call to Keys + continue + } else if err != nil { + return nil, fmt.Errorf("unable to get peer key %v: %v", key, err) + } + peerEntry := &pb.PeerEntry{} + if err := protojson.Unmarshal([]byte(v), peerEntry); err != nil { + return nil, err + } + ret[peerID] = peerEntry + } + } + logger.Infof("Retrieved %v peers from peerdb", len(ret)) + return ret, nil +} diff --git a/host/peer/peerdb/peerdb_test.go b/host/peer/peerdb/peerdb_test.go new file mode 100644 index 0000000..cdb8ae4 --- /dev/null +++ b/host/peer/peerdb/peerdb_test.go @@ -0,0 +1,183 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package peerdb + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/peerid" + "github.com/signalapp/svr2/util" +) + +func TestPeerDB(t *testing.T) { + ctx := context.Background() + s := miniredis.RunT(t) + peerdb := New(config.RedisConfig{ + Addrs: []string{s.Addr()}, + Name: "test", + }) + defer peerdb.Close() + + peer1 := [32]byte{1} + peer2 := [32]byte{2} + if err := peerdb.Insert(ctx, peer1, "host1", time.Minute); 
err != nil { + t.Fatal(err) + } + if err := peerdb.Insert(ctx, peer2, "host2", time.Minute); err != nil { + t.Fatal(err) + } + + host1, err := peerdb.Lookup(context.Background(), peer1) + if err != nil { + t.Error(err) + } + if host1 == nil || *host1 != "host1" { + t.Errorf("Lookup(%v)=%v, want %v", peer1, host1, "host1") + } + + host2, err := peerdb.Lookup(context.Background(), peer2) + if err != nil { + t.Error(err) + } + if host2 == nil || *host2 != "host2" { + t.Errorf("Lookup(%v)=%v, want %v", peer1, host1, "host2") + } +} + +func TestMissingPeer(t *testing.T) { + s := miniredis.RunT(t) + peerdb := New(config.RedisConfig{ + Addrs: []string{s.Addr()}, + Name: "test", + }) + defer peerdb.Close() + + peer1 := [32]byte{1} + v, err := peerdb.Lookup(context.Background(), peer1) + if err != nil { + t.Errorf("missing peer shouldn't error: %v", err) + } + if v != nil { + t.Errorf("Lookup(%v)=%v, want nil", peer1, v) + } + +} + +type EntryStatus int + +const ( + EntryStatusMissing EntryStatus = iota + EntryStatusSelf // self + EntryStatusNonMember // peer that isn't in the raft cluster + EntryStatusMember // peer that is in the raft cluster + EntryStatusRecentMember // peer that has been seen recently + EntryStatusMatchingHostname // peer with the same hostname as we have +) + +func TestCreationLock(t *testing.T) { + s := miniredis.RunT(t) + peerdb := New(config.RedisConfig{ + Addrs: []string{s.Addr()}, + Name: "test", + }) + defer peerdb.Close() + + peer0 := [32]byte{byte(0)} + peer1 := [32]byte{byte(1)} + if err := peerdb.acquireCreationLock(context.Background(), peer0); err != nil { + t.Error(err) + } + if err := peerdb.acquireCreationLock(context.Background(), peer0); err == nil { + t.Error("lock can only be acquired once") + } + if err := peerdb.acquireCreationLock(context.Background(), peer1); err == nil { + t.Error("lock can only be acquired once") + } +} + +func TestFindRaftMember(t *testing.T) { + tests := []struct { + name string + peers []EntryStatus + 
expectedIdx int + }{ + {"pick_member_1", []EntryStatus{EntryStatusMember}, 0}, + {"pick_member", []EntryStatus{EntryStatusSelf, EntryStatusMember, EntryStatusNonMember}, 1}, + {"pick_recent", []EntryStatus{EntryStatusNonMember, EntryStatusMember, EntryStatusMember, EntryStatusRecentMember}, 3}, + {"no_members", []EntryStatus{EntryStatusSelf, EntryStatusNonMember, EntryStatusMissing, EntryStatusMatchingHostname}, 0}, + {"only_self", []EntryStatus{EntryStatusSelf}, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := miniredis.RunT(t) + peerdb := New(config.RedisConfig{ + Addrs: []string{s.Addr()}, + Name: "test", + }) + defer peerdb.Close() + + oldTime := time.Unix(100, 0) + recentTime := time.Unix(101, 0) + peerdb.clock = util.TestAt(oldTime) + + // add peers to redis according to their status + self := [32]byte{byte(255)} + peers := make([]peerid.PeerID, len(tt.peers)) + for i := 0; i < len(tt.peers); i++ { + peers[i] = [32]byte{byte(i)} + hostname := fmt.Sprintf("host%v", i) + var err error + switch tt.peers[i] { + case EntryStatusMissing: + continue + case EntryStatusNonMember: + err = peerdb.Insert(context.Background(), peers[i], hostname, time.Minute) + case EntryStatusMember: + err = peerdb.JoinedRaft(context.Background(), peers[i], hostname, time.Minute) + case EntryStatusRecentMember: + peerdb.clock = util.TestAt(recentTime) + err = peerdb.JoinedRaft(context.Background(), peers[i], hostname, time.Minute) + peerdb.clock = util.TestAt(oldTime) + case EntryStatusMatchingHostname: + err = peerdb.JoinedRaft(context.Background(), peers[i], "self", time.Minute) + case EntryStatusSelf: + self = peers[i] + } + if err != nil { + t.Error(err) + } + } + + // add ourself to redis too + if err := peerdb.Insert(context.Background(), self, "self", time.Minute); err != nil { + t.Error(err) + } + + if tt.expectedIdx == -1 { + // expect a timeout + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + 
_, err := peerdb.FindRaftMember(ctx, self, "self") + if err != ctx.Err() { + t.Errorf("findRaftMember() = %v, want %v", err, ctx.Err()) + } + return + } + + got, err := peerdb.FindRaftMember(context.Background(), self, "self") + if err != nil { + t.Error(err) + } + if want := peers[tt.expectedIdx]; got != want { + t.Errorf("FindRaftMember() = %v, want %v", got, want) + } + }) + } +} diff --git a/host/peer/sequence_number.go b/host/peer/sequence_number.go new file mode 100644 index 0000000..96208a7 --- /dev/null +++ b/host/peer/sequence_number.go @@ -0,0 +1,84 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package peer + +import ( + "errors" + "fmt" + + pb "github.com/signalapp/svr2/proto" +) + +type sequenceNumber struct { + epoch uint32 + seq uint64 +} + +// follows returns true if this sequence number directly follows the provided sequence number. +// This is the case if either it's the next in the sequence, or it is from a greater epoch. +func (s sequenceNumber) follows(pred sequenceNumber) bool { + return (s.epoch > pred.epoch && s.seq == 0) || (s.epoch == pred.epoch && s.seq == pred.seq+1) +} + +// next returns the sequenceNumber of the next message within an enclave peer session +func (s sequenceNumber) next() sequenceNumber { + return sequenceNumber{ + epoch: s.epoch, + seq: s.seq + 1, + } +} + +// nextEpoch returns the sequenceNumber of the next epoch, typically denoting the start of a new +// enclave peer session +func (s sequenceNumber) nextEpoch() sequenceNumber { + return sequenceNumber{ + epoch: s.epoch + 1, + seq: 0, + } +} + +// cmp compares two sequence numbers, ordering by epoch first and then by seq +// +// returns: +// +// < 0 if this is less than the provided sequenceNumber +// == 0 if this is equal to the provided sequenceNumber +// > 0 if this is greater than the provided sequenceNumber +func (s sequenceNumber) cmp(o sequenceNumber) int { + // Compare explicitly instead of subtracting: the fields are unsigned, so a + // difference wraps around and int(s.epoch - o.epoch) comes out positive + // even when s.epoch < o.epoch. + switch { + case s.epoch < o.epoch: + return -1 + case s.epoch > o.epoch: + return 1 + case s.seq < o.seq: + return -1 + case s.seq > o.seq: + return 1 + } + return 0 +} + 
+func (s sequenceNumber) proto() *pb.SequenceNumber { + return &pb.SequenceNumber{ + Epoch: s.epoch, + Seq: s.seq, + } +} + +func makeSeqno(p *pb.SequenceNumber) (sequenceNumber, error) { + if p == nil { + return sequenceNumber{}, errors.New("expected a sequence number present on message") + } + return sequenceNumber{ + epoch: p.Epoch, + seq: p.Seq, + }, nil +} + +func (s sequenceNumber) String() string { + return fmt.Sprintf("%v:%v", s.epoch, s.seq) +} diff --git a/host/peer/serialize.go b/host/peer/serialize.go new file mode 100644 index 0000000..51e6548 --- /dev/null +++ b/host/peer/serialize.go @@ -0,0 +1,154 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package peer + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + + "google.golang.org/protobuf/proto" + + "github.com/signalapp/svr2/peerid" + pb "github.com/signalapp/svr2/proto" +) + +/* + * The host protocol supports only PeerConnectionMessages + * + * Each message is sent by first writing a varint length, + * followed be the message contents + */ + +func writeFramed(w io.Writer, m *pb.PeerConnectionMessage) error { + bs, err := proto.Marshal(m) + if err != nil { + return err + } + var frameBuf [binary.MaxVarintLen64]byte + n := binary.PutUvarint(frameBuf[:], uint64(len(bs))) + if _, err = w.Write(frameBuf[:n]); err != nil { + return err + } + + _, err = w.Write(bs) + return err +} + +func readFramed(r *bufio.Reader) (*pb.PeerConnectionMessage, error) { + bs, err := readFramedRaw(r) + if err != nil { + return nil, err + } + var msg pb.PeerConnectionMessage + if err := proto.Unmarshal(bs, &msg); err != nil { + return nil, err + } + return &msg, nil +} + +func readFramedRaw(r *bufio.Reader) ([]byte, error) { + length, err := binary.ReadUvarint(r) + if err != nil { + return nil, err + } + if length > maxMessageLength { + return nil, fmt.Errorf("message of length %v too long (or corrupt)", length) + } + + dst := make([]byte, length) + if _, err = 
io.ReadFull(r, dst); err != nil { + return nil, err + } + return dst, nil +} + +func writeAck(w io.Writer, seqno sequenceNumber) error { + return writeFramed(w, &pb.PeerConnectionMessage{ + Inner: &pb.PeerConnectionMessage_DataAck{ + DataAck: &pb.PeerConnectionDataAck{Seqno: seqno.proto()}, + }, + }) +} + +func readAck(r *bufio.Reader) (sequenceNumber, error) { + msg, err := readFramed(r) + if err != nil { + return sequenceNumber{}, fmt.Errorf("readAck: %w", err) + } + in, ok := msg.Inner.(*pb.PeerConnectionMessage_DataAck) + if !ok { + return sequenceNumber{}, errors.New("readAck: unexpected peer connection message") + } + seqno, err := makeSeqno(in.DataAck.Seqno) + if err != nil { + return sequenceNumber{}, errors.New("readAck: no sequence number provided") + } + return seqno, nil +} + +func writeHello(w io.Writer, from, to peerid.PeerID) error { + return writeFramed( + w, + &pb.PeerConnectionMessage{ + Inner: &pb.PeerConnectionMessage_Hello{ + Hello: &pb.PeerConnectionHello{FromPeerId: from[:], ToPeerId: to[:]}, + }, + }, + ) +} + +func readHello(r *bufio.Reader) (from, to peerid.PeerID, returnedErr error) { + msg, err := readFramed(r) + if err != nil { + returnedErr = fmt.Errorf("readHello: %w", err) + return + } + in, ok := msg.Inner.(*pb.PeerConnectionMessage_Hello) + if !ok { + returnedErr = errors.New("hello: unexpected peer connection message") + return + } + from, returnedErr = peerid.Make(in.Hello.FromPeerId) + if returnedErr != nil { + return + } + to, returnedErr = peerid.Make(in.Hello.ToPeerId) + if returnedErr != nil { + return + } + return +} + +func writeHelloAck(w io.Writer, lastAck sequenceNumber) error { + return writeFramed( + w, + &pb.PeerConnectionMessage{ + Inner: &pb.PeerConnectionMessage_HelloAck{ + HelloAck: &pb.PeerConnectionHelloAck{ + LastAck: lastAck.proto(), + }, + }, + }, + ) +} + +func readHelloAck(r *bufio.Reader) (sequenceNumber, error) { + var seqno sequenceNumber + + msg, err := readFramed(r) + if err != nil { + return seqno, 
fmt.Errorf("readHelloAck: %w", err) + } + in, ok := msg.Inner.(*pb.PeerConnectionMessage_HelloAck) + if !ok { + return seqno, errors.New("hello ack: unexpected peer connection message") + } + if seqno, err = makeSeqno(in.HelloAck.LastAck); err != nil { + return seqno, err + } + return seqno, nil +} diff --git a/host/peer/server.go b/host/peer/server.go new file mode 100644 index 0000000..05ffb2f --- /dev/null +++ b/host/peer/server.go @@ -0,0 +1,307 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package peer + +import ( + "bufio" + "context" + "fmt" + "net" + "strconv" + "sync" + + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/peerid" + "golang.org/x/sync/errgroup" + + metrics "github.com/armon/go-metrics" + pb "github.com/signalapp/svr2/proto" +) + +// PeerServer implements the host to host communication protocol +// +// Clients send messages to the server that are routed to the server's local enclave. +// The PeerServer responds with acks only. 
+// +// Internally, PeerServer accepts new connections, identifies the initiating peer, +// and passes the connection to the appropriate peerReceiver responsible for that peer +type PeerServer struct { + esender EnclaveSender + me peerid.PeerID + eg *errgroup.Group + ctx context.Context + + receiversMu sync.Mutex + receivers map[peerid.PeerID]*peerReceiver +} + +// NewPeerServer creates a new peer server which must be started with Listen +func NewPeerServer(ctx context.Context, me peerid.PeerID, enclaveSender EnclaveSender) *PeerServer { + eg, ctx := errgroup.WithContext(ctx) + return &PeerServer{ + esender: enclaveSender, + me: me, + receivers: make(map[peerid.PeerID]*peerReceiver), + eg: eg, + ctx: ctx, + } +} + +var ( + connectCounter = []string{"peer", "server", "connect"} + activeClientGauge = []string{"peer", "server", "activeClients"} + reconnectCounter = []string{"peer", "server", "reconnect"} + disconnectCounter = []string{"peer", "server", "disconnect"} + receiveCounter = []string{"peer", "server", "receive"} +) + +// Listen for new connections on addr +// +// Returns only after cancellation or a fatal error is encountered. 
Listen takes ownership +// of calling Close on the provided net.Listener +func (p *PeerServer) Listen(ln net.Listener) error { + + p.eg.Go(func() error { + for { + conn, err := ln.Accept() + if err != nil { + return err + } + + metrics.IncrCounter(connectCounter, 1) + p.eg.Go(func() error { + p.handleConnection(conn) + return nil + }) + } + }) + <-p.ctx.Done() + // stop the listener so accept unblocks + ln.Close() + return p.eg.Wait() +} + +// handleConnection initiates the peer handshake, and then hands-off the +// connection to a peerReceiver +// +// If this is the first connection for a peerID, the receiver is created, +// otherwise the receiver is just notified of the new connection +func (p *PeerServer) handleConnection(conn net.Conn) { + reader := bufio.NewReader(conn) + them, us, err := readHello(reader) + if err != nil { + logger.Warnw("failed to read hello from client", + "peer", conn.RemoteAddr(), + "err", err) + conn.Close() + return + } + if us != p.me { + logger.Warnw("got incorrect peer ID", "peer", conn.RemoteAddr(), "peerID", them) + conn.Close() + return + } + + logger.Infow("received connect attempt from peer", "peer", conn.RemoteAddr(), "peerID", them) + + // notify the receiver that it should switch over to this connection (replacing + // any existing connection). 
Note that it's still possible this kicks off after we've + // been cancelled -- nbd, the receiveLoop will bail out immediately + select { + case <-p.ctx.Done(): + conn.Close() + case p.getOrCreate(them).conns <- conn: + } +} + +// getOrCreate looks up the peerReceiver, creating one (and starting its receiveLoop) if it doesn't exist +func (p *PeerServer) getOrCreate(them peerid.PeerID) *peerReceiver { + p.receiversMu.Lock() + defer p.receiversMu.Unlock() + + receiver, ok := p.receivers[them] + if !ok { + logger.Infow("received first connect from peer", "peerID", them) + // first connection from this peer, + // create a receiver and start it + receiver = &peerReceiver{ + local: p.me, + remote: them, + esender: p.esender, + conns: make(chan net.Conn)} + p.receivers[them] = receiver + metrics.SetGauge(activeClientGauge, float32(len(p.receivers))) // set after the insert so the new peer is counted + p.eg.Go(func() error { return receiver.receiveLoop(p.ctx) }) + } + return receiver + +} + +// peerReceiver handles inbound messages +// from a single remote peerID +// +// Only a single connection is allowed, +// and on a reconnect from the same peer +// the previous connection is first shutdown +type peerReceiver struct { + seqno sequenceNumber // The sequence number of the last message received + conns chan net.Conn // Newly accepted connections for this peerID + local peerid.PeerID + remote peerid.PeerID + esender EnclaveSender +} + +// updateSeqno updates the sequence number after a message is received +// +// Returns true if the sequence number is new +// Returns an error if the sequence number is invalid to see +// in the current state +func (p *peerReceiver) updateSeqno(seqno sequenceNumber) (bool, error) { + if seqno.cmp(p.seqno) <= 0 { + return false, nil + } + if !seqno.follows(p.seqno) { + return false, fmt.Errorf("expected message seqno=%v, got %v", p.seqno.next(), seqno) + } + p.seqno = seqno + return true, nil +} + +// receiveLoop spins up a handler for inbound connections, cancelling and +// replacing the existing handler if a peer 
reconnects +// +// runs until it is cancelled +func (p *peerReceiver) receiveLoop(ctx context.Context) error { + labels := []metrics.Label{{Name: "peerID", Value: p.remote.String()}} + done := make(chan error, 1) + var lastConn net.Conn + for { + select { + case <-ctx.Done(): + // cancelled, stop the conn handler + if lastConn != nil { + lastConn.Close() // close the current connection + <-done // wait until done (ignore error we probably caused) + } + return ctx.Err() + case conn := <-p.conns: + // the same peer reconnected, replace the conn handler + if lastConn != nil { + metrics.IncrCounterWithLabels(reconnectCounter, 1, labels) + logger.Infow("peer client reconnected", + "peer_id", p.remote, + "peer", lastConn.RemoteAddr(), + ) + lastConn.Close() // close the current connection + <-done // wait until done (ignore error we probably caused) + } + lastConn = conn + + // spin up handler for new connection + go func() { done <- p.handleConnection(conn) }() + case err := <-done: + metrics.IncrCounterWithLabels(disconnectCounter, 1, labels) + // finished without an external connection close, + // log and wait for the next connect + logger.Warnw("error in receive handler", + "err", err, + "peer_id", p.remote, + "peer", lastConn.RemoteAddr()) + lastConn = nil + } + } +} + +func (p *peerReceiver) handleConnection(conn net.Conn) error { + defer conn.Close() + + logger.Debugw("sending helloAck to peer", + "seqno", p.seqno, + "peerID", p.remote, + "peer", conn.RemoteAddr(), + ) + peerLabel := metrics.Label{Name: "peerID", Value: p.remote.String()} + + // Before getting to this handler, we should have read + // the peer's initial hello. Now respond with our current + // sequence number so the client knows where to start sending + if err := writeHelloAck(conn, p.seqno); err != nil { + return err + } + + reader := bufio.NewReader(conn) + + // the server read loop + // 1. read a request from the client + // 2. update our latest sequence number + // 3. 
if this is a new request, forward it to the enclave + // 4. write an ack back to the client + // 5. repeat + for { + // read message from client + pcm, err := readFramed(reader) + if err != nil { + return fmt.Errorf("read data: %w", err) + } + + msg, ok := pcm.Inner.(*pb.PeerConnectionMessage_Data) + if !ok { + // log and ignore, might be a message we don't know about + logger.Errorw("Received unknown message", + "peer", conn.RemoteAddr(), + "peerID", p.remote, + ) + continue + } + if msg.Data == nil || msg.Data.Msg == nil { + // a data frame without a payload would nil-dereference when forwarded below; + // drop the connection before the sequence number is advanced + return fmt.Errorf("data message from peer %v has no payload", p.remote) + } + + msgSeqno, err := makeSeqno(msg.Data.Seqno) + if err != nil { + return err + } + isNew, err := p.updateSeqno(msgSeqno) + if err != nil { + return err + } + + metrics.IncrCounterWithLabels(receiveCounter, 1, []metrics.Label{ + {Name: "duplicate", Value: strconv.FormatBool(!isNew)}, + peerLabel, + }) + + if !isNew { + // we've already delivered this message + if err := writeAck(conn, p.seqno); err != nil { + return err + } + continue + } + + // this is a new message, forward it to enclave + u := pb.UntrustedMessage{ + Inner: &pb.UntrustedMessage_PeerMessage{ + PeerMessage: &pb.PeerMessage{ + PeerId: p.remote[:], + Inner: msg.Data.Msg.Inner, + }, + }, + } + + logger.Debugw("received new message from peer, forwarding to enclave", + "peer", conn.RemoteAddr(), + "peerID", p.remote, + "seqno", msgSeqno, + "type", fmt.Sprintf("%T", msg.Data.Msg.Inner), + ) + + // It is required that we do not send into the enclave + // with any concurrency. 
We must wait for the previous message + // from a peer to be processed before sending a new one + if _, err := p.esender.Send(&u); err != nil { + return err + } + if err := writeAck(conn, msgSeqno); err != nil { + return err + } + } +} diff --git a/host/peer/server_test.go b/host/peer/server_test.go new file mode 100644 index 0000000..54b3f7d --- /dev/null +++ b/host/peer/server_test.go @@ -0,0 +1,355 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package peer + +import ( + "bufio" + "bytes" + "context" + "fmt" + "net" + "testing" + + "github.com/signalapp/svr2/peerid" + pb "github.com/signalapp/svr2/proto" +) + +type mockEnclave struct { + msgs []*pb.UntrustedMessage +} + +func (t *mockEnclave) Send(p *pb.UntrustedMessage) (*pb.EnclaveMessage, error) { + t.msgs = append(t.msgs, p) + return nil, nil +} + +type receiverFixture struct { + enclave *mockEnclave + pr *PeerServer + addr string + done chan error + ctx context.Context + cancel context.CancelFunc +} + +func (f *receiverFixture) close() error { + f.cancel() + return <-f.done +} + +func (*receiverFixture) seq(i uint64) sequenceNumber { + return sequenceNumber{epoch: 0, seq: i} +} + +func startReceiver(t *testing.T) *receiverFixture { + e := &mockEnclave{} + + ctx, cancel := context.WithCancel(context.Background()) + pr := NewPeerServer(ctx, peer(0), e) + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + addr := ln.Addr().String() + + done := make(chan error) + go func() { done <- pr.Listen(ln) }() + return &receiverFixture{ + e, + pr, + addr, + done, + ctx, + cancel, + } +} + +func (f *receiverFixture) dial(t *testing.T) net.Conn { + conn, err := net.Dial("tcp", f.addr) + if err != nil { + t.Fatal(err) + } + return conn +} + +func (f *receiverFixture) handshake(t *testing.T, c net.Conn, initiatorId peerid.PeerID) sequenceNumber { + reader := bufio.NewReader(c) + if err := writeHello(c, initiatorId, f.pr.me); err != nil { + t.Error(err) + } + 
ack, err := readHelloAck(reader) + if err != nil { + t.Error(err) + } + return ack +} + +func (*receiverFixture) trySendAck(c net.Conn, data []byte, seqno sequenceNumber) error { + if err := writeFramed(c, &pb.PeerConnectionMessage{ + Inner: &pb.PeerConnectionMessage_Data{ + Data: &pb.PeerConnectionData{ + Seqno: seqno.proto(), + Msg: &pb.PeerMessage{ + Inner: &pb.PeerMessage_Syn{Syn: data}, + }, + }, + }, + }); err != nil { + return err + } + + ack, err := readAck(bufio.NewReader(c)) + if err != nil { + return err + } + if ack != seqno { + return fmt.Errorf("readAck()=%v, wanted %v", ack, seqno) + } + return nil +} + +func (f *receiverFixture) sendAck(t *testing.T, c net.Conn, data []byte, seqno sequenceNumber) { + if err := f.trySendAck(c, data, seqno); err != nil { + t.Error(err) + } +} + +func TestSimpleAck(t *testing.T) { + f := startReceiver(t) + defer f.close() + conn := f.dial(t) + defer conn.Close() + + ack := f.handshake(t, conn, peer(1)) + if ack != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + + f.sendAck(t, conn, []byte("data"), f.seq(1)) + + if len(f.enclave.msgs) != 1 { + t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 1) + } + msg := f.enclave.msgs[0] + if d, ok := msg.GetPeerMessage().Inner.(*pb.PeerMessage_Syn); !ok { + t.Errorf("not syn") + } else if !bytes.Equal(d.Syn, []byte("data")) { + t.Errorf("got data %v, want %v", d.Syn, []byte("data")) + } +} + +func TestSequenceGap(t *testing.T) { + f := startReceiver(t) + defer f.close() + conn := f.dial(t) + defer conn.Close() + + if ack := f.handshake(t, conn, peer(1)); ack != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + f.sendAck(t, conn, []byte("data"), f.seq(1)) + if err := f.trySendAck(conn, []byte("data"), f.seq(3)); err == nil { + t.Fatal("expected error sending ack") + } +} +func TestEpochGap(t *testing.T) { + f := startReceiver(t) + defer f.close() + conn := f.dial(t) + defer conn.Close() + + if ack := 
f.handshake(t, conn, peer(1)); ack != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + f.sendAck(t, conn, []byte("data"), sequenceNumber{0, 1}) + f.sendAck(t, conn, []byte("data"), sequenceNumber{2, 0}) // skipping epoch is ok + + // skipping a message in epoch 3 is not ok + if err := f.trySendAck(conn, []byte("data"), sequenceNumber{3, 1}); err == nil { + t.Fatal("expected error sending ack") + } +} + +func TestEpochAck(t *testing.T) { + f := startReceiver(t) + defer f.close() + conn := f.dial(t) + defer conn.Close() + + ack := f.handshake(t, conn, peer(1)) + if ack != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + + f.sendAck(t, conn, []byte("data"), sequenceNumber{0, 1}) + f.sendAck(t, conn, []byte("data"), sequenceNumber{2, 0}) + f.sendAck(t, conn, []byte("data"), sequenceNumber{2, 1}) + f.sendAck(t, conn, []byte("data"), sequenceNumber{3, 0}) + + if len(f.enclave.msgs) != 4 { + t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 4) + } +} + +func TestDisconnectAck(t *testing.T) { + f := startReceiver(t) + defer f.close() + + conn := f.dial(t) + ack := f.handshake(t, conn, peer(1)) + if ack != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + + f.sendAck(t, conn, []byte("data"), f.seq(1)) + f.sendAck(t, conn, []byte("data1"), f.seq(2)) + f.sendAck(t, conn, []byte("data2"), f.seq(3)) + if len(f.enclave.msgs) != 3 { + t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 3) + } + + conn.Close() + conn = f.dial(t) + ack = f.handshake(t, conn, peer(1)) + if ack != f.seq(3) { + t.Errorf("handshake ack = %v, want %v", ack, 3) + } + f.sendAck(t, conn, []byte("data4"), f.seq(4)) + if len(f.enclave.msgs) != 4 { + t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 4) + } +} + +func TestDisconnectEpochAck(t *testing.T) { + f := startReceiver(t) + defer f.close() + + conn := f.dial(t) + ack := f.handshake(t, conn, peer(1)) + if ack != 
(sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + + f.sendAck(t, conn, []byte("data"), sequenceNumber{5, 0}) + if len(f.enclave.msgs) != 1 { + t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 3) + } + + conn.Close() + conn = f.dial(t) + if ack = f.handshake(t, conn, peer(1)); ack != (sequenceNumber{5, 0}) { + t.Errorf("handshake ack = %v, want %v", ack, "5:0") + } + f.sendAck(t, conn, []byte("data4"), sequenceNumber{5, 1}) + if len(f.enclave.msgs) != 2 { + t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 2) + } +} + +func TestResend(t *testing.T) { + f := startReceiver(t) + defer f.close() + + conn := f.dial(t) + ack := f.handshake(t, conn, peer(1)) + if ack != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + + f.sendAck(t, conn, []byte("data"), f.seq(1)) + f.sendAck(t, conn, []byte("data"), f.seq(1)) +} + +func TestReconnectingClient(t *testing.T) { + f := startReceiver(t) + defer f.close() + + conn := f.dial(t) + ack := f.handshake(t, conn, peer(1)) + if ack != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + + // reconnect, server should disconnect old connection + conn2 := f.dial(t) + ack = f.handshake(t, conn2, peer(1)) + if ack != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack, 0) + } + + bs := []byte{0} + _, err := conn.Read(bs) + if err == nil { + t.Error("expected disconnect for original client") + } +} + +func TestMultiplePeers(t *testing.T) { + f := startReceiver(t) + defer f.close() + + conn1 := f.dial(t) + ack1 := f.handshake(t, conn1, peer(1)) + if ack1 != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack1, 0) + } + + conn2 := f.dial(t) + ack2 := f.handshake(t, conn2, peer(2)) + if ack2 != (sequenceNumber{}) { + t.Errorf("handshake ack = %v, want %v", ack2, 0) + } + + f.sendAck(t, conn1, []byte("data1"), f.seq(1)) + f.sendAck(t, conn2, []byte("data2"), f.seq(1)) + + if len(f.enclave.msgs) != 2 { 
+ t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 2) + } + got1 := f.enclave.msgs[0].GetPeerMessage().Inner.(*pb.PeerMessage_Syn).Syn + got2 := f.enclave.msgs[1].GetPeerMessage().Inner.(*pb.PeerMessage_Syn).Syn + + if !bytes.Equal(got1, []byte("data1")) { + t.Errorf("got data %v, want %v", got1, []byte("data1")) + } + + if !bytes.Equal(got2, []byte("data2")) { + t.Errorf("got data %v, want %v", got2, []byte("data2")) + } + +} + +func TestEndToEnd(t *testing.T) { + f := startReceiver(t) + defer f.close() + + sender := createSender(withLookup(serverID, f.addr)) + done := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) + go func() { done <- sender.run(ctx) }() + + msg := pb.PeerMessage{ + PeerId: f.pr.me[:], + Inner: &pb.PeerMessage_Syn{Syn: []byte("data")}, + } + + *sender.tx.Load() <- &msg + cancel() + <-done +} + +func TestRejectsNotMyPeerID(t *testing.T) { + f := startReceiver(t) + defer f.close() + conn := f.dial(t) + defer conn.Close() + + reader := bufio.NewReader(conn) + if err := writeHello(conn, peer(1), peer(12 /*not me*/)); err != nil { + t.Error(err) + } + if _, err := readHelloAck(reader); err == nil { + t.Fatal("accepted connection when toPeerID did not match server") + } +} diff --git a/host/peerid/peerid.go b/host/peerid/peerid.go new file mode 100644 index 0000000..65a1d8f --- /dev/null +++ b/host/peerid/peerid.go @@ -0,0 +1,39 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Package peerid provides the PeerID type. Enclaves identify remote peers by +// their PeerID, which is a a 256 bit public key that an enclave generates at startup. +// Hosts are responsible for mapping PeerIDs to the actual remote endpoints used for +// network communication. 
+package peerid + +import ( + "encoding/hex" + "fmt" +) + +type PeerID [32]byte + +func Make(s []byte) (PeerID, error) { + if len(s) != 32 { + return PeerID{}, fmt.Errorf("incorrect peer id length %v", len(s)) + } + var out PeerID + copy(out[:], s) + return out, nil +} + +// Hex parses a hexidecimal formatted peerID +func FromHex(s string) (PeerID, error) { + bs, err := hex.DecodeString(s) + if err != nil { + return PeerID{}, err + } + return Make(bs) + +} + +// String returns a hexidecimal formatted peerID (just an 8-char prefix) +func (p PeerID) String() string { + return hex.EncodeToString(p[:4]) +} diff --git a/host/proto/control.proto b/host/proto/control.proto new file mode 100644 index 0000000..439d630 --- /dev/null +++ b/host/proto/control.proto @@ -0,0 +1,181 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2.control; +option go_package = "github.com/signalapp/svr2/proto"; +option optimize_for = LITE_RUNTIME; + +import "error.proto"; +import "msgs.proto"; +import "enclaveconfig.proto"; + +// +// untrusted peer to untrusted peer messages +// + +// Sequence numbers ensure that when a client and server are disconnected then reconnected, the +// server does not miss any previously sent messages. +// +// Each message exchanged between hosts is tagged with a SequenceNumber. A sequence number consists +// of a seq that increments on each new message, and an epoch. When a message is tagged with a new +// epoch, the recepient of the message can ignore all messages from previous epochs. 
message SequenceNumber {
  // When a message is tagged with a new epoch, the recipient can ignore all
  // messages from previous epochs.
  uint32 epoch = 1;
  // Increments on each new message.
  uint64 seq = 2;
}
// EnclaveStatus describes one enclave; GetStatusControlReply carries a
// repeated list of these.
message EnclaveStatus {
  string name = 1;
  bytes peer_id = 2;
  enclaveconfig.EnclaveConfig config = 3;
  // NOTE(review): this field's type is EnclaveStatus itself, so the message
  // nests recursively. Confirm that is intended, rather than a status message
  // imported from msgs.proto.
  EnclaveStatus status = 4;
}
+func (e Error) Error() string { + if str, ok := Error_name[int32(e)]; ok { + return str + } + return fmt.Sprintf("UnknownErrorEnumValue(%d)", int32(e)) +} diff --git a/host/proto/error_test.go b/host/proto/error_test.go new file mode 100644 index 0000000..d43d34b --- /dev/null +++ b/host/proto/error_test.go @@ -0,0 +1,20 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package proto + +import ( + "testing" +) + +func TestError(t *testing.T) { + var e error = Error_Core_NoInit + if got, want := e.Error(), "Core_NoInit"; got != want { + t.Errorf("got %q want %q", got, want) + } + + e = Error(-1) + if got, want := e.Error(), "UnknownErrorEnumValue(-1)"; got != want { + t.Errorf("got %q want %q", got, want) + } +} diff --git a/host/proto/peerdb.proto b/host/proto/peerdb.proto new file mode 100644 index 0000000..886cd5e --- /dev/null +++ b/host/proto/peerdb.proto @@ -0,0 +1,16 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2.peerdb; +option go_package = "github.com/signalapp/svr2/proto"; +option optimize_for = LITE_RUNTIME; + + +message PeerEntry { + int64 join_ts = 1; + int64 last_update_ts = 2; + string addr = 3; // hostname:port that other nodes can use to access this peer + bool raft_member = 4; // if true, this peer is a raft member +} diff --git a/host/raftmanager/raftmanager.go b/host/raftmanager/raftmanager.go new file mode 100644 index 0000000..b32ac16 --- /dev/null +++ b/host/raftmanager/raftmanager.go @@ -0,0 +1,191 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Package raftmanager provides utilities for joining or creating raft groups +package raftmanager + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/peerid" + pb "github.com/signalapp/svr2/proto" +) + +type MemberFinder interface { + // 
FindRaftMember finds an existing raft member to join, or return self if raft group should be created + FindRaftMember(ctx context.Context, me peerid.PeerID, localPeerAddr string) (peerid.PeerID, error) +} + +// EnclaveRequester provides a request/response channel to the enclave +type EnclaveRequester interface { + // SendTransaction sends a request to the enclave and returns the enclave's response. + // Implementations must tag the provided request with a requestID. + SendTransaction(req *pb.HostToEnclaveRequest) (*pb.HostToEnclaveResponse, error) +} + +// RaftManager can be used to create a new raft group, or join an existing one. +type RaftManager struct { + config *config.Config + me peerid.PeerID + localPeerAddr string + enclaveRequester EnclaveRequester + memberFinder MemberFinder + isMember bool // true once a successful call to CreateOrJoin is made +} + +func New(me peerid.PeerID, enclaveRequester EnclaveRequester, memberFinder MemberFinder, config *config.Config) *RaftManager { + return &RaftManager{ + config, + me, + config.PeerAddr, + enclaveRequester, + memberFinder, + false, + } +} + +// MarkLiveFun is used to indicate to other peers that +// this peer is a good candidate to join raft with +type MarkLiveFun func(ctx context.Context) error + +// RunRefresher periodically checks the enclave to verify this peer has good connectivity +// to other raft nodes, and if so calls the provided MarkLiveFun +func (r *RaftManager) RunRefresher(ctx context.Context, markLive MarkLiveFun) error { + if !r.isMember { + // only allow starting the refresher once we're part of raft + return errors.New("only raft members can update their peerdb status") + } + + // always initially mark ourselves as live + if err := markLive(ctx); err != nil { + return err + } + + logger.Infof("Starting raft status refresher, will refresh every %v", r.config.Raft.RefreshStatusDuration) + + // periodically query the enclave and make sure we're still a reasonable candidate for others to join + 
ch := time.Tick(r.config.Raft.RefreshStatusDuration) + for { + select { + case <-ch: + if err := r.enclaveJoinable(); err != nil { + logger.Warnw("could not get enclave status", "err", err) + } else if err := markLive(ctx); err != nil { + logger.Warnw("failed to mark ourselves as joinable", "err", err) + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// enclaveJoinable returns nil if other peers may use this node to join a raft cluster +func (r *RaftManager) enclaveJoinable() error { + resp, err := r.enclaveRequester.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_GetEnclaveStatus{ + GetEnclaveStatus: true, + }, + }) + if err != nil { + return err + } + if v, ok := resp.Inner.(*pb.HostToEnclaveResponse_Status); ok && v.Status != pb.Error_OK { + return v.Status + } + v, ok := resp.Inner.(*pb.HostToEnclaveResponse_GetEnclaveStatusReply) + if !ok { + return fmt.Errorf("unexpected enclave reply %v", resp) + } + + partitionStatus := v.GetEnclaveStatusReply + if partitionStatus == nil { + return errors.New("unexpected enclave reply, missing partition status") + } + + if partitionStatus.RaftState != pb.RaftState_RAFTSTATE_LOADED_PART_OF_GROUP { + return fmt.Errorf("not part of raft group, state: %v", partitionStatus.RaftState) + } + return nil +} + +// CreateOrJoin finds and joins an existing raft group or creates a new raft group +// if one cannot be found. Errors from this method indicate an enclave join operation +// has failed. 
+func (r *RaftManager) CreateOrJoin(ctx context.Context) error { + raftPeer, err := r.memberFinder.FindRaftMember(ctx, r.me, r.localPeerAddr) + if err != nil { + return errors.New("failed to fetch raft peers") + } + if raftPeer == r.me { + logger.Infow("attempting to create a new raft group") + if err := r.createRaft(); err != nil { + return err + } + } else { + logger.Infow("attempting to join existing raft group", "peerID", raftPeer) + if err := r.joinExistingRaftPeer(raftPeer); err != nil { + return err + } + } + + // refresh attestation now that we've joined the raft group + logger.Infow("attempting to refresh attestation") + if err := r.refreshAttestation(); err != nil { + return err + } + r.isMember = true + return nil +} + +func (r *RaftManager) refreshAttestation() error { + // send a message to the enclave to update attestation + return r.sendToEnclave(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_RefreshAttestation{ + RefreshAttestation: &pb.RefreshAttestation{RotateClientKey: false}, + }, + }) +} + +func (r *RaftManager) createRaft() error { + // send a message to the enclave to create a raft group + return r.sendToEnclave(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_CreateNewRaftGroup{CreateNewRaftGroup: true}, + }) +} + +// joinExistingRaftPeer joins a raft group by requesting membership from the provided peer +func (r *RaftManager) joinExistingRaftPeer(peer peerid.PeerID) error { + // send a message to the enclave to join the raft group using the provided peer + return r.sendToEnclave(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_JoinRaft{ + JoinRaft: &pb.JoinRaftRequest{PeerId: peer[:]}, + }, + }) +} + +// sendToEnclave sends a request to the enclave and verify the response is Error_OK +func (r *RaftManager) sendToEnclave(req *pb.HostToEnclaveRequest) error { + reply, err := r.enclaveRequester.SendTransaction(req) + + if err != nil { + return fmt.Errorf("failed to send to enclave : %w", err) + } + + if reply 
== nil { + return fmt.Errorf("unexpected enclave response (nil)") + } + + status, ok := reply.Inner.(*pb.HostToEnclaveResponse_Status) + if !ok { + return errors.New("unexpected enclave response") + } + if status.Status != pb.Error_OK { + return fmt.Errorf("failed to join raft: %w", status.Status) + } + return nil +} diff --git a/host/raftmanager/raftmanager_test.go b/host/raftmanager/raftmanager_test.go new file mode 100644 index 0000000..67aa2bc --- /dev/null +++ b/host/raftmanager/raftmanager_test.go @@ -0,0 +1,76 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package raftmanager + +import ( + "context" + "testing" + + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/peerid" + + pb "github.com/signalapp/svr2/proto" +) + +type mockEnclave struct { + msgs []*pb.HostToEnclaveRequest +} + +func (t *mockEnclave) SendTransaction(p *pb.HostToEnclaveRequest) (*pb.HostToEnclaveResponse, error) { + t.msgs = append(t.msgs, p) + return &pb.HostToEnclaveResponse{ + Inner: &pb.HostToEnclaveResponse_Status{ + Status: pb.Error_OK, + }, + }, nil +} + +type mockFinder peerid.PeerID + +func (f mockFinder) FindRaftMember(ctx context.Context, me peerid.PeerID, addr string) (peerid.PeerID, error) { + return peerid.PeerID(f), nil +} + +func TestCreate(t *testing.T) { + peer0 := [32]byte{byte(0)} + mockEnclave := &mockEnclave{} + + raftManager := New(peer0, mockEnclave, mockFinder(peer0), config.Default()) + if err := raftManager.CreateOrJoin(context.Background()); err != nil { + t.Error(err) + } + if len(mockEnclave.msgs) != 2 { + t.Errorf("want %v enclave messages, got %v", 2, len(mockEnclave.msgs)) + } + + if mockEnclave.msgs[0].GetCreateNewRaftGroup() == false { + t.Errorf("got message %v, want create request", mockEnclave.msgs[0]) + } + if mockEnclave.msgs[1].GetRefreshAttestation() == nil { + t.Errorf("got message %v, want refresh request", mockEnclave.msgs[1]) + } +} + +func TestJoin(t *testing.T) { + peer0 := 
// ErrLimitExceeded indicates that the requested permit exceeds
// the configured rate limit for the user. After waiting for
// RetryAfter, the same request should succeed.
type ErrLimitExceeded struct{ RetryAfter time.Duration }

// Error implements the error interface; the message includes RetryAfter.
func (e ErrLimitExceeded) Error() string {
	return fmt.Sprintf("rate limit exceeded, retry after : %v", e.RetryAfter)
}
+ Limit(ctx context.Context, key string) error +} + +// NewRedisLimiter returns a Limiter backed by redis +func NewRedisLimiter(cfg *config.Config) Limiter { + return &redisLimiter{ + fmt.Sprintf("%s::leaky_bucket", cfg.Redis.Name), + redis_rate.Limit{ + Rate: cfg.Limit.LeakRateScalar, + Burst: cfg.Limit.BucketSize, + Period: cfg.Limit.LeakRateDuration, + }, + redis_rate.NewLimiter(redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: cfg.Redis.Addrs, + Password: cfg.Redis.Password, + }))} +} + +type redisLimiter struct { + prefix string // prefix for rate limit buckets + limit redis_rate.Limit // configured limit + limiter *redis_rate.Limiter +} + +func (r *redisLimiter) Limit(ctx context.Context, key string) error { + res, err := r.limiter.Allow(ctx, r.redisKey(key), r.limit) + if err != nil { + return err + } + if res.Allowed <= 0 { + return ErrLimitExceeded{res.RetryAfter} + } + return nil +} + +func (r *redisLimiter) redisKey(userKey string) string { + return fmt.Sprintf("%s::%s", r.prefix, userKey) +} + +type alwaysAllow struct{} + +func (r alwaysAllow) Limit(context.Context, string) error { return nil } + +// AlwaysAllow provides a Limiter that will always allow callers through +var AlwaysAllow = Limiter(alwaysAllow{}) diff --git a/host/rate/rate_test.go b/host/rate/rate_test.go new file mode 100644 index 0000000..36110b4 --- /dev/null +++ b/host/rate/rate_test.go @@ -0,0 +1,90 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package rate + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/signalapp/svr2/config" +) + +func testLimiter(t *testing.T, limitConfig config.RateLimitConfig) Limiter { + cfg := config.Default() + cfg.Limit = limitConfig + s := miniredis.RunT(t) + cfg.Redis.Addrs = []string{s.Addr()} + return NewRedisLimiter(cfg) +} + +func TestRedisLimiter(t *testing.T) { + cfg := config.RateLimitConfig{ + BucketSize: 2, + LeakRateScalar: 1, + 
LeakRateDuration: 100 * time.Millisecond, + } + limiter := testLimiter(t, cfg) + + var count int + var err error + var exceeded ErrLimitExceeded + + // wait until we get a limit exceeded error + for count = 0; !errors.As(err, &exceeded); count++ { + err = limiter.Limit(context.Background(), "test1") + } + if count < 2 { + t.Errorf("took %v limits, should be at least %v", count, cfg.BucketSize) + } + + start := time.Now() + + // keep trying until we don't get limited + for errors.As(err, &exceeded) { + err = limiter.Limit(context.Background(), "test1") + time.Sleep(cfg.LeakRateDuration / 10) + } + duration := time.Since(start) + + if err != nil { + t.Error(err) + } + if duration > cfg.LeakRateDuration*2 { + t.Errorf("took %v to get a permit, should only need %v", duration, cfg.LeakRateDuration) + } +} + +func TestRedisLimiterExhaust(t *testing.T) { + limiter := testLimiter(t, config.RateLimitConfig{ + BucketSize: 10, + LeakRateScalar: 1, + LeakRateDuration: time.Hour, + }) + for i := 0; i < 10; i++ { + if err := limiter.Limit(context.Background(), "test1"); err != nil { + t.Errorf("iter %v : %v", i, err) + } + } + + // 11th request should get rate limited + var errExceed ErrLimitExceeded + err := limiter.Limit(context.Background(), "test1") + if !errors.As(err, &errExceed) { + t.Fatalf("Limit(11)=%v, want %v", err, "ErrLimitExceeded") + } + if errExceed.RetryAfter < (time.Hour - 5*time.Second) { + t.Fatalf("RetryAfter = %v, should be at least %v", errExceed.RetryAfter, time.Hour-5*time.Second) + } + if errExceed.RetryAfter > time.Hour { + t.Fatalf("RetryAfter = %v, should be at most %v", errExceed.RetryAfter, time.Hour) + } + + // unrelated key should be fine + if err := limiter.Limit(context.Background(), "test2"); err != nil { + t.Error(err) + } +} diff --git a/host/service/service.go b/host/service/service.go new file mode 100644 index 0000000..2ba472b --- /dev/null +++ b/host/service/service.go @@ -0,0 +1,162 @@ +// Copyright 2023 Signal Messenger, LLC +// 
SPDX-License-Identifier: AGPL-3.0-only + +// Package service creates and initializes all components required to run an SVR instance +package service + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/signalapp/svr2/auth" + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/dispatch" + "github.com/signalapp/svr2/enclave" + "github.com/signalapp/svr2/health" + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/peer" + "github.com/signalapp/svr2/peer/peerdb" + "github.com/signalapp/svr2/raftmanager" + "github.com/signalapp/svr2/rate" + "github.com/signalapp/svr2/util" + "github.com/signalapp/svr2/web/handlers" + "github.com/signalapp/svr2/web/middleware" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + pb "github.com/signalapp/svr2/proto" + _ "net/http/pprof" +) + +// Start starts all SVR components and only returns when a component has encountered an +// unrecoverable error or the provided context has been cancelled. +func Start(ctx context.Context, econfig *pb.InitConfig, hconfig *config.Config, enclavePath string, authenticator auth.Auth) error { + g, ctx := errgroup.WithContext(ctx) + + // Start up the control server immediately, for debugging and liveness checking. + // Use DefaultServeMux, since it's got PProf stuff already attached by net/http/pprof. 
+ controlMux := http.DefaultServeMux + healthErr := errors.New("joining raft") + live, ready := health.New(healthErr), health.New(healthErr) + controlMux.Handle("/health/live", middleware.Instrument(live)) + controlMux.Handle("/health/ready", middleware.Instrument(ready)) + g.Go(func() error { + logger.Infof("Starting control http server on %v", hconfig.ControlListenAddr) + return http.ListenAndServe(hconfig.ControlListenAddr, controlMux) + }) + + sgx := enclave.SGXEnclave() + logger.Infof("creating enclave") + if err := sgx.Init(enclavePath, econfig); err != nil { + logger.Fatalf("creating enclave: %v", err) + } + defer sgx.Close() + c, nodeID := sgx.OutputMessages(), sgx.PID() + + logger.WithGlobal(zap.String("me", nodeID.String())) + logger.Infow("created enclave") + + txGen := &util.TxGenerator{} + + dispatcher := dispatch.New(hconfig.Raft, txGen, sgx, c) + + // listen for peer network requests + ln, err := net.Listen("tcp", hconfig.PeerAddr) + if err != nil { + logger.Fatalf("failed to net.listen: %v", err) + } + peerServer := peer.NewPeerServer(ctx, nodeID, dispatcher) + g.Go(func() error { return peerServer.Listen(ln) }) + + logger.Infof("started peer server on %v", ln.Addr()) + + peerDB := peerdb.New(hconfig.Redis) + + // let other peers look us up by our nodeID + insertCtx, insertCancel := context.WithTimeout(ctx, time.Minute) + if err := peerDB.Insert(insertCtx, nodeID, hconfig.PeerAddr, hconfig.InitialRedisPeerDBTTL); err != nil { + logger.Fatalf("failed to update peerdb : %v", err) + } + insertCancel() + + logger.Infof("built peer lookup") + + // create network senders + peerClient := peer.NewPeerClient(ctx, nodeID, peerDB, &hconfig.Peer) + g.Go(func() error { return peerClient.Run() }) + + // now that everything's wired up, start processing enclave requests + g.Go(func() error { return dispatcher.Run(ctx, peerClient) }) + + rateLimiter := rate.NewRedisLimiter(hconfig) + + // set up http server + clientMux := http.NewServeMux() + 
clientMux.Handle(fmt.Sprintf("/v1/%s", hconfig.EnclaveID), + middleware.Instrument(middleware.AuthCheck(authenticator, + middleware.RateLimit(rateLimiter, handlers.NewWebsocket(&hconfig.Request, dispatcher))))) + clientMux.Handle("/v1/delete", + middleware.Instrument(middleware.AuthCheck(authenticator, + middleware.RateLimit(rateLimiter, handlers.NewDeleteBackup(dispatcher))))) + + // control endpoints + controlMux.Handle("/control/loglevel", middleware.Instrument(handlers.NewSetLogLevel(hconfig, dispatcher))) + controlMux.Handle("/control", middleware.Instrument(handlers.NewControl(dispatcher))) + + g.Go(func() error { + logger.Infof("Starting client http server on %v", hconfig.ClientListenAddr) + return http.ListenAndServe(hconfig.ClientListenAddr, clientMux) + }) + + // wait until we successfully create a raft group or join an existing one + raftManager := raftmanager.New(nodeID, dispatcher, peerDB, hconfig) + joinCtx, joinCancel := context.WithTimeout(ctx, time.Minute) + if err := raftManager.CreateOrJoin(joinCtx); err != nil { + logger.Fatalf("failure to join raft : %v", err) + } + joinCancel() + + // successfully joined raft, periodically refresh our peerdb status + g.Go(func() error { + return raftManager.RunRefresher(ctx, func(innerCtx context.Context) error { + timeoutCtx, cancel := context.WithTimeout(innerCtx, time.Minute) + defer cancel() + return peerDB.JoinedRaft(timeoutCtx, nodeID, hconfig.PeerAddr, hconfig.RecurringRedisPeerDBTTL) + }) + }) + + // fully capable of servicing user requests, mark ready + live.Set(nil) + ready.Set(nil) + + sigtermC := make(chan os.Signal, 1) + signal.Notify(sigtermC, os.Signal(syscall.SIGTERM)) + g.Go(func() error { + select { + case <-sigtermC: + logger.Errorf("Received SIGTERM, gracefully shutting down") + // If we're the leader, stop being the leader. 
+ logger.Errorf("Relinquishing Raft leadership") + dispatcher.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_RelinquishLeadership{RelinquishLeadership: true}, + }) + logger.Errorf("Requesting removal from Raft") + dispatcher.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_RequestRemoval{RequestRemoval: true}, + }) + logger.Errorf("Done gracefully shutting down, exiting") + return errors.New("SIGTERM") + case <-ctx.Done(): + return ctx.Err() + } + }) + + return g.Wait() +} diff --git a/host/service/service_test.go b/host/service/service_test.go new file mode 100644 index 0000000..db43f25 --- /dev/null +++ b/host/service/service_test.go @@ -0,0 +1,281 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package service + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "testing" + "time" + + "google.golang.org/protobuf/proto" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/alicebob/miniredis/v2" + "github.com/gorilla/websocket" + "github.com/signalapp/svr2/auth" + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/servicetest" + "github.com/signalapp/svr2/web/client" + + pb "github.com/signalapp/svr2/proto" +) + +func waitForReady(t *testing.T, controlAddr string, timeout time.Duration) { + url := fmt.Sprintf("http://%v/health/ready", controlAddr) + if err := servicetest.WaitFor200(timeout, url); err != nil { + t.Fatal(err) + } +} + +func randomPort(t *testing.T) int { + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + port := listener.Addr().(*net.TCPAddr).Port + if err := listener.Close(); err != nil { + t.Fatal(err) + } + return port + +} + +func dial(t *testing.T, cfg *config.Config) *websocket.Conn { + u := url.URL{Scheme: "ws", Host: cfg.ClientListenAddr, Path: "v1/enclave"} + hdrs := http.Header{} + 
hdrs.Add("Authorization", fmt.Sprintf("Basic %s", base64.URLEncoding.EncodeToString([]byte("00112233445566778899aabbccddeeff:foo")))) + c, _, err := websocket.DefaultDialer.Dial(u.String(), hdrs) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { c.Close() }) + return c +} + +func backup(t *testing.T, tc *servicetest.TestClient, pin []byte, data []byte) { + // send a backup request + defer tc.Sc.Close() + r := tc.Send(&pb.Request{Inner: &pb.Request_Backup{ + Backup: &pb.BackupRequest{ + Data: data, + Pin: pin, + MaxTries: 5, + }, + }}) + if br, ok := r.Inner.(*pb.Response_Backup); !ok { + t.Fatalf("Unexpected response to backup: %v", r) + } else if br.Backup.Status != pb.BackupResponse_OK { + t.Fatalf("Incorrect response: %v", br) + } +} + +func expose(t *testing.T, tc *servicetest.TestClient, data []byte) { + defer tc.Sc.Close() + // send a expose request + r := tc.Send(&pb.Request{Inner: &pb.Request_Expose{ + Expose: &pb.ExposeRequest{ + Data: data, + }, + }}) + if br, ok := r.Inner.(*pb.Response_Expose); !ok { + t.Fatalf("Unexpected response to backup: %v", r) + } else if br.Expose.Status != pb.ExposeResponse_OK { + t.Fatalf("Incorrect response: %v", br) + } +} + +func restore(t *testing.T, tc *servicetest.TestClient, pin []byte, expectedData []byte) { + defer tc.Sc.Close() + // send a restore request + r := tc.Send(&pb.Request{Inner: &pb.Request_Restore{ + Restore: &pb.RestoreRequest{ + Pin: pin, + }, + }}) + if rr, ok := r.Inner.(*pb.Response_Restore); !ok { + t.Fatalf("Unexpected response to restore: %v", r) + } else if rr.Restore.Status != pb.RestoreResponse_OK { + t.Fatalf("Incorrect response: %v", rr) + } else if !bytes.Equal(rr.Restore.Data, expectedData) { + t.Fatalf("Restored bytes %v, want %v", rr.Restore.Data, expectedData) + } +} + +func startService(t *testing.T) *config.Config { + redis := miniredis.RunT(t) + + hconfig := config.Default() + hconfig.ClientListenAddr = fmt.Sprintf("localhost:%v", randomPort(t)) + 
hconfig.ControlListenAddr = fmt.Sprintf("localhost:%v", randomPort(t)) + hconfig.PeerAddr = fmt.Sprintf("localhost:%v", randomPort(t)) + hconfig.Redis.Addrs = []string{redis.Addr()} + + econfig := pb.InitConfig{ + EnclaveConfig: &pb.EnclaveConfig{ + E2ETxnTimeoutTicks: 30, + Raft: &pb.RaftConfig{ + ElectionTicks: 30, + HeartbeatTicks: 15, + ReplicationChunkBytes: 1048576, + ReplicaVotingTimeoutTicks: 60, + ReplicaMembershipTimeoutTicks: 300, + LogMaxBytes: 10 << 20, + }, + }, + GroupConfig: &pb.RaftGroupConfig{ + DbVersion: pb.DatabaseVersion_DATABASE_VERSION_SVR2, + MinVotingReplicas: 1, + MaxVotingReplicas: 5, + AttestationTimeout: 3600, + Simulated: true, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + go func() { Start(ctx, &econfig, hconfig, "../../enclave/build/enclave.test", auth.AlwaysAllow) }() + waitForReady(t, hconfig.ControlListenAddr, time.Minute) + return hconfig +} + +func TestService(t *testing.T) { + // The enclave can only be initialized once, so only run setup once + cfg := startService(t) + t.Run("TestSetLogLevel", func(t *testing.T) { + testSetLogLevel(t, cfg) + }) + t.Run("TestBackupRestore", func(t *testing.T) { + testBackupRestore(t, cfg) + }) + t.Run("TestBadArgs", func(t *testing.T) { + testBadArgs(t, cfg) + }) + t.Run("TestKeyRotation", func(t *testing.T) { + testKeyRotation(t, cfg) + }) +} + +func testSetLogLevel(t *testing.T, cfg *config.Config) { + orig := cfg.Log.Level.Level() + cfg.Log.Level.SetLevel(zapcore.ErrorLevel) + logger.Init(cfg) + + isEnabled := func(level zapcore.Level) bool { + return zap.L().Check(level, "test") != nil + } + + if !isEnabled(zapcore.ErrorLevel) { + t.Errorf("isEnabled(Error) = false, want true") + } + + if isEnabled(zapcore.InfoLevel) { + t.Errorf("isEnabled(Info) = true, want false") + } + + // use the server http endpoint to delete the backup for user + controlURL := fmt.Sprintf("http://%v/control/loglevel", cfg.ControlListenAddr) + resp, err := 
http.PostForm(controlURL, url.Values{"level": []string{"info"}}) + if err != nil { + t.Fatal(err) + } + defer http.PostForm(controlURL, url.Values{"level": []string{orig.String()}}) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bs, _ := ioutil.ReadAll(resp.Body) + t.Errorf("PUT loglevel = %v:%s, want %v", resp.Status, bs, http.StatusOK) + } + + if !isEnabled(zapcore.InfoLevel) { + t.Errorf("isEnabled(Info) = false, want true") + } + +} + +func testBackupRestore(t *testing.T, cfg *config.Config) { + + // send a backup request + pin := servicetest.RandBytes(t, 32) + data := servicetest.RandBytes(t, 40) + tc := servicetest.NewTestClient(t, dial(t, cfg)) + backup(t, tc, pin, data) + + tc = servicetest.NewTestClient(t, dial(t, cfg)) + expose(t, tc, data) + + // send a restore request + tc = servicetest.NewTestClient(t, dial(t, cfg)) + restore(t, tc, pin, data) +} + +func testBadArgs(t *testing.T, cfg *config.Config) { + c := dial(t, cfg) + + // read the handshake start message + _, msg, err := c.ReadMessage() + if err != nil { + t.Error(err) + } + + var start pb.ClientHandshakeStart + if err := proto.Unmarshal(msg, &start); err != nil { + t.Error(err) + } + + // send garbage + if err := c.WriteMessage(websocket.BinaryMessage, []byte{1}); err != nil { + t.Error(err) + } + + _, _, err = c.ReadMessage() + var wsErr *websocket.CloseError + if !errors.As(err, &wsErr) { + t.Fatalf("expected close frame error, got err = %v", err) + } + if wsErr.Code != 4003 { + t.Fatalf("bad handshake got close code %v, want %v", wsErr.Code, 4003) + } +} + +func testKeyRotation(t *testing.T, cfg *config.Config) { + + // handshake with old public key + tc := servicetest.NewTestClient(t, dial(t, cfg)) + + // force a rekey + cc := client.ControlClient{Addr: cfg.ControlListenAddr} + refreshReq := pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_RefreshAttestation{ + RefreshAttestation: &pb.RefreshAttestation{RotateClientKey: true}, + }, + } + if resp, err := 
cc.Do(&refreshReq); err != nil { + t.Fatal(err) + } else if resp.GetStatus() != pb.Error_OK { + t.Fatalf("RefreshAttestation(true)=%v, want=%v", resp.GetStatus(), pb.Error_OK) + } + + // should be able to do a backup even after rekey + pin := servicetest.RandBytes(t, 32) + data := servicetest.RandBytes(t, 40) + backup(t, tc, pin, data) + + oldkey := tc.Sc.PubKey + tc = servicetest.NewTestClient(t, dial(t, cfg)) + if bytes.Equal(tc.Sc.PubKey, oldkey) { + t.Fatalf("handshake key should not match after key rotation") + } + tc.Send(&pb.Request{Inner: &pb.Request_Delete{Delete: &pb.DeleteRequest{}}}) +} diff --git a/host/servicetest/servicetest.go b/host/servicetest/servicetest.go new file mode 100644 index 0000000..4eaadae --- /dev/null +++ b/host/servicetest/servicetest.go @@ -0,0 +1,81 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package servicetest + +import ( + "crypto/rand" + "fmt" + "io" + "net/http" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/signalapp/svr2/util" + "github.com/signalapp/svr2/web/client" + + pb "github.com/signalapp/svr2/proto" +) + +type TestClient struct { + t *testing.T + Sc *client.SVR2Client +} + +func (tc *TestClient) Send(req *pb.Request) *pb.Response { + res, err := tc.Sc.Send(req) + if err != nil { + tc.t.Fatalf(err.Error()) + } + return res +} + +func NewTestClient(t *testing.T, c *websocket.Conn) *TestClient { + client, err := client.NewClient(c) + if err != nil { + t.Fatal(err) + } + return &TestClient{t, client} +} + +func RandBytes(t *testing.T, count uint32) []byte { + bs := make([]byte, count) + if _, err := rand.Read(bs); err != nil { + t.Fatalf("rand: %v", err) + } + return bs +} + +func RetryFun[T any](timeout time.Duration, fun func() (T, error)) (T, error) { + timech := time.After(timeout) + var err error + var res T + for { + select { + case <-timech: + return res, fmt.Errorf("timeout: %w", err) + default: + if res, err = fun(); err == nil { + return res, 
nil + } + time.Sleep(util.Min(time.Second, timeout/10)) + } + } +} + +func WaitFor200(timeout time.Duration, url string) error { + _, err := RetryFun(timeout, func() (interface{}, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("status=%v : %v", resp.Status, body) + } + return nil, nil + }) + return err +} diff --git a/host/util/clock.go b/host/util/clock.go new file mode 100644 index 0000000..aab2ec3 --- /dev/null +++ b/host/util/clock.go @@ -0,0 +1,29 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package util + +import ( + "time" +) + +// Clock provides an interface for accessing the current time. +type Clock interface { + Now() time.Time +} + +type realClock struct{} + +func (r *realClock) Now() time.Time { + return time.Now() +} + +// RealClock returns a clock which uses time.Now. +var RealClock Clock = (*realClock)(nil) + +// TestAt is a Clock that returns a set single point in time. 
// TxGenerator hands out unique transaction ids for transactional
// request/responses to the enclave. The zero value is ready to use.
type TxGenerator struct {
	txcounter uint64
}

// NextID atomically increments the internal counter and returns the new
// value, so the first id produced by a zero-value generator is 1.
func (t *TxGenerator) NextID() uint64 {
	next := atomic.AddUint64(&t.txcounter, 1)
	return next
}
// Platform identifies the client platform named in a Signal User-Agent string.
type Platform string

// AllowedVersions maps a platform to the set of client versions that may be
// reported as a metric label for that platform.
type AllowedVersions = map[Platform]map[string]bool

// Canonical (lower-case) platform names.
const (
	PlatformIOS     = "ios"
	PlatformDesktop = "desktop"
	PlatformAndroid = "android"
)

// uaPattern matches "Signal-<Platform>/<Version>[ <Specifiers>]",
// case-insensitively. The capture groups must be named Platform, Version and
// Specifiers because ParseUserAgent looks them up by name via SubexpIndex.
var uaPattern = regexp.MustCompile("(?i)^Signal-(?P<Platform>Android|Desktop|iOS)/(?P<Version>[^ ]+)( (?P<Specifiers>.+))?$")

// UserAgent is the parsed form of a Signal client User-Agent header.
type UserAgent struct {
	Platform             Platform
	Version              string
	AdditionalSpecifiers string
}
Version tags will be included for those specified in AllowedVersions +func ParseTagsIncludeVersions(uaString string, allowedVersions AllowedVersions) []metrics.Label { + ua, err := ParseUserAgent(uaString) + if err != nil { + return []metrics.Label{{Name: "platform", Value: "unknown"}} + } + return ua.Tags(allowedVersions) +} + +// ParseUserAgent parses a user agent string provided via a User-Agent header +func ParseUserAgent(uaString string) (*UserAgent, error) { + if len(uaString) == 0 { + return nil, errors.New("User-Agent string is blank") + } + matches := uaPattern.FindStringSubmatch(uaString) + if len(matches) == 0 { + return nil, errors.New("unrecognized user agent") + } + platform := parsePlatform(matches[uaPattern.SubexpIndex("Platform")]) + version := matches[uaPattern.SubexpIndex("Version")] + specifiers := strings.TrimSpace(matches[uaPattern.SubexpIndex("Specifiers")]) + return &UserAgent{platform, version, specifiers}, nil +} + +func parsePlatform(platform string) Platform { + switch strings.ToLower(platform) { + case PlatformDesktop: + return PlatformDesktop + case PlatformIOS: + return PlatformIOS + case PlatformAndroid: + return PlatformAndroid + } + // should never happen - platform should be already be validated by `uaPattern` + return "unexpected-parse-failure" +} diff --git a/host/util/user_agent_test.go b/host/util/user_agent_test.go new file mode 100644 index 0000000..d0d63f9 --- /dev/null +++ b/host/util/user_agent_test.go @@ -0,0 +1,82 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package util + +import ( + "reflect" + "testing" + + "github.com/armon/go-metrics" +) + +func TestParseUserAgent(t *testing.T) { + tests := []struct { + userAgent string + parsed *UserAgent + }{ + {"Signal-Android/4.68.3 Android/25", &UserAgent{PlatformAndroid, "4.68.3", "Android/25"}}, + {"This is obviously not a reasonable User-Agent string.", nil}, + {"Signal-Android/4.68.3", &UserAgent{PlatformAndroid, "4.68.3", ""}}, 
+ {"Signal-Desktop/1.2.3 Linux", &UserAgent{PlatformDesktop, "1.2.3", "Linux"}}, + {"Signal-Desktop/1.2.3 macOS", &UserAgent{PlatformDesktop, "1.2.3", "macOS"}}, + {"Signal-Desktop/1.2.3 Windows", &UserAgent{PlatformDesktop, "1.2.3", "Windows"}}, + {"Signal-Desktop/1.2.3", &UserAgent{PlatformDesktop, "1.2.3", ""}}, + {"Signal-Desktop/1.32.0-beta.3", &UserAgent{PlatformDesktop, "1.32.0-beta.3", ""}}, + {"Signal-iOS/3.9.0 (iPhone; iOS 12.2; Scale/3.00)", &UserAgent{PlatformIOS, "3.9.0", "(iPhone; iOS 12.2; Scale/3.00)"}}, + {"Signal-iOS/3.9.0 iOS/14.2", &UserAgent{PlatformIOS, "3.9.0", "iOS/14.2"}}, + {"Signal-iOS/3.9.0", &UserAgent{PlatformIOS, "3.9.0", ""}}, + } + for _, tt := range tests { + t.Run(tt.userAgent, func(t *testing.T) { + parsed, _ := ParseUserAgent(tt.userAgent) + if !reflect.DeepEqual(parsed, tt.parsed) { + t.Errorf("ParseUserAgent(%v) = %v, want %v", tt.userAgent, parsed, tt.parsed) + } + }) + } +} +func TestParseTags(t *testing.T) { + tests := []struct { + userAgent string + allowedVersions []string + tags []metrics.Label + }{ + { + "Signal-Android/4.68.3 Android/25", + []string{"4.68.3"}, + []metrics.Label{{Name: "platform", Value: "android"}, {Name: "clientVersion", Value: "4.68.3"}}, + }, + { + "Signal-Android/4.68.3 Android/25", + []string{"4.68.4", "4.68.3"}, + []metrics.Label{{Name: "platform", Value: "android"}, {Name: "clientVersion", Value: "4.68.3"}}, + }, + + { + "Signal-Android/4.68.3 Android/25", + []string{"4.68.4"}, + []metrics.Label{{Name: "platform", Value: "android"}}, + }, + + { + "A bad useragent Signal-Android/4.68.3 Android/25", + []string{"4.68.3"}, + []metrics.Label{{Name: "platform", Value: "unknown"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.userAgent, func(t *testing.T) { + m := make(map[string]bool) + for _, version := range tt.allowedVersions { + m[version] = true + } + tags := ParseTagsIncludeVersions(tt.userAgent, AllowedVersions{PlatformAndroid: m}) + if !reflect.DeepEqual(tags, tt.tags) { + 
t.Errorf("ParseTags(%v) = %v, want %v", tt.userAgent, tags, tt.tags) + } + }) + } + +} diff --git a/host/util/util.go b/host/util/util.go new file mode 100644 index 0000000..8a28b56 --- /dev/null +++ b/host/util/util.go @@ -0,0 +1,58 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Package util contains general purpose utilities +package util + +import ( + "context" + "time" + + "golang.org/x/exp/constraints" +) + +// Min returns the minimum of a and b +func Min[T constraints.Ordered](a, b T) T { + if a < b { + return a + } + return b +} + +// Max returns the maximum of a and b +func Max[T constraints.Ordered](a, b T) T { + if a > b { + return a + } + return b +} + +// Clamp restricts the value to the range [lo, hi] +func Clamp[T constraints.Ordered](v, lo, hi T) T { + return Max(lo, Min(hi, v)) +} + +// RetryWithBackoff repeatedly attempts to call `fun`, increasing the wait between each call +func RetryWithBackoff(ctx context.Context, fun func() error, minSleep time.Duration, maxSleep time.Duration) error { + _, err := RetrySupplierWithBackoff(ctx, func() (interface{}, error) { return nil, fun() }, minSleep, maxSleep) + return err +} + +// RetrySupplierWithBackoff repeatedly attempts to call `fun` to produce a value, increasing the wait between each call +func RetrySupplierWithBackoff[T any](ctx context.Context, fun func() (T, error), minSleep time.Duration, maxSleep time.Duration) (T, error) { + sleepTime := time.Duration(0) + var res T + var err error + for { + select { + case <-ctx.Done(): + return res, ctx.Err() + case <-time.After(sleepTime): + } + res, err = fun() + if err == nil { + return res, nil + } + sleepTime = Clamp(sleepTime*2, minSleep, maxSleep) + } +} diff --git a/host/web/client/control_client.go b/host/web/client/control_client.go new file mode 100644 index 0000000..d421e06 --- /dev/null +++ b/host/web/client/control_client.go @@ -0,0 +1,57 @@ +// Copyright 2023 Signal Messenger, LLC +// 
SPDX-License-Identifier: AGPL-3.0-only + +package client + +import ( + "bytes" + "fmt" + "io" + "net/http" + + "google.golang.org/protobuf/encoding/protojson" + + pb "github.com/signalapp/svr2/proto" +) + +type ControlClient struct { + Addr string +} + +func (cc *ControlClient) Do(request *pb.HostToEnclaveRequest) (*pb.HostToEnclaveResponse, error) { + bs, err := protojson.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal proto : %w", err) + } + return cc.DoJSON(bs) +} + +func (cc *ControlClient) DoJSON(request []byte) (*pb.HostToEnclaveResponse, error) { + url := fmt.Sprintf("http://%v/control", cc.Addr) + req, err := http.NewRequest(http.MethodPut, url, bytes.NewBuffer(request)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed : %w", err) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body : %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("request failed, status=%v, body=%s", resp.Status, body) + } + + pbResponse := pb.HostToEnclaveResponse{} + if err := protojson.Unmarshal(body, &pbResponse); err != nil { + return nil, fmt.Errorf("could not parse server response, body=%s : %w", body, err) + } + return &pbResponse, nil +} diff --git a/host/web/client/svr2client.go b/host/web/client/svr2client.go new file mode 100644 index 0000000..bb71e9b --- /dev/null +++ b/host/web/client/svr2client.go @@ -0,0 +1,99 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// Package client provides client implementations for SVR2Client endpoints +package client + +import ( + "fmt" + + "github.com/flynn/noise" + "github.com/gorilla/websocket" + "google.golang.org/protobuf/proto" + + pb "github.com/signalapp/svr2/proto" +) + +type SVR2Client struct { + c 
*websocket.Conn + encrypt *noise.CipherState + decrypt *noise.CipherState + PubKey []byte +} + +func (sc *SVR2Client) Close() { + sc.c.Close() +} + +func (sc *SVR2Client) Send(req *pb.Request) (*pb.Response, error) { + bs, err := proto.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal: %w", err) + } + + var ciphertext []byte + if ciphertext, err = sc.encrypt.Encrypt(ciphertext, nil, bs); err != nil { + return nil, fmt.Errorf("encrypt: %w", err) + } + if err = sc.c.WriteMessage(websocket.BinaryMessage, ciphertext); err != nil { + return nil, fmt.Errorf("writews: %w", err) + } + + _, msg, err := sc.c.ReadMessage() + if err != nil { + return nil, fmt.Errorf("readws: %w", err) + } + + var plaintext []byte + if plaintext, err = sc.decrypt.Decrypt(plaintext, nil, msg); err != nil { + return nil, fmt.Errorf("decrypt: %w", err) + } + var m pb.Response + if err := proto.Unmarshal(plaintext, &m); err != nil { + return nil, fmt.Errorf("unmarshal: %w", err) + } + return &m, nil +} + +func NewClient(c *websocket.Conn) (*SVR2Client, error) { + // extract the server public key + _, msg, err := c.ReadMessage() + if err != nil { + return nil, fmt.Errorf("readws: %w", err) + } + + var start pb.ClientHandshakeStart + if err := proto.Unmarshal(msg, &start); err != nil { + return nil, fmt.Errorf("unmarshal: %w", err) + } + + // start a noise handshake + hs, err := noise.NewHandshakeState(noise.Config{ + CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256), + Pattern: noise.HandshakeNK, + Initiator: true, + PeerStatic: start.TestOnlyPubkey, + }) + if err != nil { + return nil, fmt.Errorf("hanshake init: %w", err) + } + var out []byte + out, _, _, err = hs.WriteMessage(out, nil) + if err != nil { + return nil, fmt.Errorf("handshake: %w", err) + } + + if err = c.WriteMessage(websocket.BinaryMessage, out); err != nil { + return nil, fmt.Errorf("writews: %w", err) + } + + if _, msg, err = c.ReadMessage(); err != nil { + return nil, 
fmt.Errorf("handshake readws: %w", err) + } + + _, encrypt, decrypt, err := hs.ReadMessage(nil, msg) + if err != nil { + return nil, fmt.Errorf("handshake read: %w", err) + } + return &SVR2Client{c, encrypt, decrypt, start.TestOnlyPubkey}, nil +} diff --git a/host/web/handlers/control.go b/host/web/handlers/control.go new file mode 100644 index 0000000..e48a5c8 --- /dev/null +++ b/host/web/handlers/control.go @@ -0,0 +1,79 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package handlers + +import ( + "fmt" + "io" + "mime" + "net/http" + + "github.com/signalapp/svr2/logger" + "google.golang.org/protobuf/encoding/protojson" + + pb "github.com/signalapp/svr2/proto" +) + +// NewControl returns a handler that takes HTTP PUT requests with a [pb.HostToEnclaveRequest] +// and forwards them to the enclave, returning the enclave's response +func NewControl(server EnclaveRequester) http.Handler { + return &controlHandler{enclaveRequester: server} +} + +type controlHandler struct { + enclaveRequester EnclaveRequester +} + +func (c *controlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.NotFound(w, r) + return + } + + contentType := r.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil || mediaType != "application/json" { + http.Error(w, fmt.Sprintf("invalid content type %v: %v", err, mediaType), http.StatusBadRequest) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + req := &pb.HostToEnclaveRequest{} + if err := protojson.Unmarshal(body, req); err != nil { + http.Error(w, fmt.Sprintf("invalid request proto : %v", err), http.StatusBadRequest) + return + } + + if req.RequestId != 0 { + logger.Warnf("control set request id %v, it will be ignored", req.RequestId) + req.RequestId = 0 + } + + resp, err := c.enclaveRequester.SendTransaction(req) + if 
err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err = responseErr(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + + } + + out, err := protojson.Marshal(resp) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(out); err != nil { + logger.Warnw("error writing control response", "err", err) + } +} diff --git a/host/web/handlers/delete_backup.go b/host/web/handlers/delete_backup.go new file mode 100644 index 0000000..0712e28 --- /dev/null +++ b/host/web/handlers/delete_backup.go @@ -0,0 +1,66 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package handlers + +import ( + "encoding/hex" + "net/http" + + "google.golang.org/protobuf/proto" + + pb "github.com/signalapp/svr2/proto" +) + +// NewDeleteBackup returns a handler that takes HTTP DELETE requests and notifies +// the enclave to delete the backup associated with the user (provided via basic auth) +func NewDeleteBackup(server EnclaveRequester) http.Handler { + return &deleteBackupHandler{ + enclaveRequester: server, + } +} + +type deleteBackupHandler struct { + enclaveRequester EnclaveRequester +} + +func (d *deleteBackupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + http.NotFound(w, r) + return + } + user, _, _ := r.BasicAuth() + authID, err := hex.DecodeString(user) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } else if len(authID) != 16 { + http.Error(w, "auth ID not 16 bytes", http.StatusBadRequest) + return + } + deleteReq := pb.Request{ + Inner: &pb.Request_Delete{Delete: &pb.DeleteRequest{}}, + } + marshalled, err := proto.Marshal(&deleteReq) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + resp, err := 
d.enclaveRequester.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_DatabaseRequest{ + DatabaseRequest: &pb.DatabaseRequest{ + Request: marshalled, + AuthenticatedId: authID, + }, + }, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err = responseErr(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} diff --git a/host/web/handlers/handlers.go b/host/web/handlers/handlers.go new file mode 100644 index 0000000..ebd8276 --- /dev/null +++ b/host/web/handlers/handlers.go @@ -0,0 +1,24 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +// package handlers provides http handlers for SVR2 endpoints +package handlers + +import ( + "fmt" + pb "github.com/signalapp/svr2/proto" +) + +type EnclaveRequester interface { + // SendTransaction sends a request to the enclave and returns the enclave's response. + // Implementations must tag the provided request with a requestID. 
+ SendTransaction(req *pb.HostToEnclaveRequest) (*pb.HostToEnclaveResponse, error) +} + +// responseErr check a response for an enclave error status +func responseErr(r *pb.HostToEnclaveResponse) error { + if e, ok := r.Inner.(*pb.HostToEnclaveResponse_Status); ok && e.Status != pb.Error_OK { + return fmt.Errorf("transaction %d failed with code: %w", r.RequestId, e.Status) + } + return nil +} diff --git a/host/web/handlers/set_log_level.go b/host/web/handlers/set_log_level.go new file mode 100644 index 0000000..2cd2246 --- /dev/null +++ b/host/web/handlers/set_log_level.go @@ -0,0 +1,92 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package handlers + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/logger" + "go.uber.org/zap/zapcore" + + pb "github.com/signalapp/svr2/proto" +) + +// NewSetLogLevel returns a handler that takes requests to dynamically configure +// the log level. 
The desired log level should be provided in a POST request with +// "Content-Type: application/x-www-form-urlencoded" body, ex: level=DEBUG +func NewSetLogLevel(config *config.Config, enclaveRequester EnclaveRequester) http.Handler { + return &setLogLevelHandler{config, enclaveRequester} +} + +type setLogLevelHandler struct { + config *config.Config + enclaveRequester EnclaveRequester +} + +func (s *setLogLevelHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + if err := r.ParseForm(); err != nil { + http.Error(w, fmt.Sprintf("bad body: %v", err), http.StatusBadRequest) + return + } + + hostLevel, enclaveLevel, err := parseLogLevel(r.PostForm) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // set the level on the host + s.config.Log.Level.SetLevel(hostLevel) + + // set the level on the enclave + resp, err := s.enclaveRequester.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_SetLogLevel{ + SetLogLevel: enclaveLevel, + }, + }) + if err == nil { + err = responseErr(resp) + } + if err != nil { + logger.Errorw("failed to set enclave log level", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + logger.Infof("successfully set log levels, host=%v enclave=%v", hostLevel, enclaveLevel) + w.WriteHeader(http.StatusOK) +} + +func parseLogLevel(values url.Values) (zapcore.Level, pb.EnclaveLogLevel, error) { + level := values.Get("level") + if level == "" { + return zapcore.InvalidLevel, pb.EnclaveLogLevel_LOG_LEVEL_NONE, errors.New("must provide log level") + } + + switch strings.TrimSpace(strings.ToUpper(level)) { + case "FATAL": + return zapcore.ErrorLevel, pb.EnclaveLogLevel_LOG_LEVEL_FATAL, nil + case "ERROR": + return zapcore.ErrorLevel, pb.EnclaveLogLevel_LOG_LEVEL_ERROR, nil + case "WARNING": + return zapcore.WarnLevel, pb.EnclaveLogLevel_LOG_LEVEL_WARNING, nil + case "INFO": + return 
zapcore.InfoLevel, pb.EnclaveLogLevel_LOG_LEVEL_INFO, nil + case "DEBUG": + return zapcore.DebugLevel, pb.EnclaveLogLevel_LOG_LEVEL_DEBUG, nil + case "VERBOSE": + return zapcore.DebugLevel, pb.EnclaveLogLevel_LOG_LEVEL_VERBOSE, nil + case "MAX": + return zapcore.DebugLevel, pb.EnclaveLogLevel_LOG_LEVEL_MAX, nil + } + return zapcore.InvalidLevel, pb.EnclaveLogLevel_LOG_LEVEL_NONE, fmt.Errorf("invalid log level %s", level) +} diff --git a/host/web/handlers/set_log_level_test.go b/host/web/handlers/set_log_level_test.go new file mode 100644 index 0000000..e88a687 --- /dev/null +++ b/host/web/handlers/set_log_level_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package handlers + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/signalapp/svr2/config" + "go.uber.org/zap/zapcore" + + pb "github.com/signalapp/svr2/proto" +) + +type niceEnclave struct{} +type errorEnclave struct{} + +func (*niceEnclave) SendTransaction(p *pb.HostToEnclaveRequest) (*pb.HostToEnclaveResponse, error) { + return &pb.HostToEnclaveResponse{Inner: &pb.HostToEnclaveResponse_Status{ + Status: pb.Error_OK, + }}, nil +} +func (*errorEnclave) SendTransaction(p *pb.HostToEnclaveRequest) (*pb.HostToEnclaveResponse, error) { + return nil, errors.New("test") +} + +func TestEnclaveError(t *testing.T) { + cfg := config.Default() + mux := http.NewServeMux() + mux.Handle("/control/loglevel", NewSetLogLevel(cfg, &errorEnclave{})) + ts := httptest.NewServer(mux) + defer ts.Close() + + resp, err := http.PostForm(fmt.Sprintf("%v/control/loglevel", ts.URL), url.Values{ + "level": []string{"info"}, + }) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("POST loglevel(name) = %v, want %v", resp.StatusCode, http.StatusInternalServerError) + } + +} + +func TestBadArgs(t *testing.T) { + cfg := config.Default() + mux := 
http.NewServeMux() + mux.Handle("/control/loglevel", NewSetLogLevel(cfg, &niceEnclave{})) + ts := httptest.NewServer(mux) + defer ts.Close() + + for _, tt := range []url.Values{ + {}, + {"level": []string{"foo"}}, + {"levelz": []string{"info"}}, + } { + name := fmt.Sprintf("%v", tt) + t.Run(name, func(t *testing.T) { + resp, err := http.PostForm(fmt.Sprintf("%v/control/loglevel", ts.URL), url.Values{}) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("POST loglevel(name) = %v, want %v", resp.StatusCode, http.StatusBadRequest) + } + + }) + } +} + +func TestParseLogLevel(t *testing.T) { + tests := []struct { + key string + level string + valid bool + expectedHost zapcore.Level + expectedEnclave pb.EnclaveLogLevel + }{ + {"level", "info", true, zapcore.InfoLevel, pb.EnclaveLogLevel_LOG_LEVEL_INFO}, + {"level", "InFo", true, zapcore.InfoLevel, pb.EnclaveLogLevel_LOG_LEVEL_INFO}, + {"level", "verbose", true, zapcore.DebugLevel, pb.EnclaveLogLevel_LOG_LEVEL_VERBOSE}, + {"levelz", "info", false, zapcore.InvalidLevel, pb.EnclaveLogLevel_LOG_LEVEL_NONE}, + {"LEVEL", "info", false, zapcore.InvalidLevel, pb.EnclaveLogLevel_LOG_LEVEL_NONE}, + } + for _, tt := range tests { + name := fmt.Sprintf("%s=%s", tt.key, tt.level) + t.Run(name, func(t *testing.T) { + val := url.Values{} + val.Set(tt.key, tt.level) + host, enclave, err := parseLogLevel(val) + if tt.valid && err != nil { + t.Errorf("expected success, got %v", err) + } + if host != tt.expectedHost { + t.Errorf("parseLogLevel(%s)=%v, want %v", name, host, tt.expectedHost) + } + if enclave != tt.expectedEnclave { + t.Errorf("parseLogLevel(%s)=%v, want %v", name, enclave, tt.expectedEnclave) + } + }) + } +} diff --git a/host/web/handlers/websocket.go b/host/web/handlers/websocket.go new file mode 100644 index 0000000..99c5bf7 --- /dev/null +++ b/host/web/handlers/websocket.go @@ -0,0 +1,228 @@ +// Copyright 2023 Signal Messenger, LLC +// 
SPDX-License-Identifier: AGPL-3.0-only + +package handlers + +import ( + "encoding/hex" + "errors" + "fmt" + "net/http" + + "github.com/armon/go-metrics" + "github.com/gorilla/websocket" + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/util" + "google.golang.org/protobuf/proto" + + pb "github.com/signalapp/svr2/proto" +) + +// NewWebsocket returns a handler that takes http GET requests and allows clients to +// upgrade to a websocket. They can then use the SVR2 protocol to perform a handshake +// and exchange encrypted messages with the enclave +func NewWebsocket(requestConfig *config.RequestConfig, enclaveRequester EnclaveRequester) http.Handler { + return &websocketHandler{ + clock: util.RealClock, + requestConfig: requestConfig, + upgrader: websocket.Upgrader{ + HandshakeTimeout: requestConfig.WebsocketHandshakeTimeout, + EnableCompression: false, + }, + enclaveRequester: enclaveRequester, + } +} + +const ( + maxReadLimit = 1024 * 128 +) + +var ( + enclaveErrorCounterName = []string{"websocket", "enclaveError"} + websocketClosureCounterName = []string{"websocket", "closeCode"} +) + +type websocketHandler struct { + clock util.Clock + requestConfig *config.RequestConfig + enclaveRequester EnclaveRequester + upgrader websocket.Upgrader +} + +// entrypoint for all app client requests +func (h *websocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + user, _, _ := r.BasicAuth() + // Get an ereport from the enclave + authID, err := hex.DecodeString(user) + if err != nil { + http.Error(w, "auth ID invalid hex", http.StatusBadRequest) + return + } else if len(authID) != 16 { + http.Error(w, "auth ID not 16 bytes", http.StatusBadRequest) + return + } + + c, err := h.upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Warnw("ws upgrade failed", "err", err) + return + } + defer c.Close() + + c.SetReadLimit(maxReadLimit) + + err = h.handleClientWebsocket(c, authID) + var wsErr *websocket.CloseError + 
if errors.As(err, &wsErr) { + logger.Debugw("client close error", "err", wsErr) + labels := [1]metrics.Label{metrics.Label{Name: "code", Value: fmt.Sprintf("%d", wsErr.Code)}} + metrics.IncrCounterWithLabels(websocketClosureCounterName, 1, labels[:]) + return + } + + // Send a close frame and forget, no need to wait for close response + if err := h.writeMessage(c, websocket.CloseMessage, closeMessage(r, err)); err != nil { + logger.Infow("failed to write close message", "err", err) + } +} + +// Custom websocket close codes +// +// Application errors are [4000, 4015] +const ( + WSBadArgs = 4003 + WSInternalError = 4013 + WSUnavailable = 4014 +) + +// closeMessage builds a websocket closeMessage from an error. If the underlying error +// came from the enclave, it will be marshalled to the appropriate application error. +func closeMessage(r *http.Request, err error) []byte { + if err == nil { + return websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + } + + var enclaveErr pb.Error + if !errors.As(err, &enclaveErr) { + logger.Warnw("error processing client request", "err", err) + return websocket.FormatCloseMessage(WSInternalError, "") + } + + // err might have supplemental error information. 
Only return enclaveErr in the + // close frame which only contains static enclave error codes + msg := enclaveErr.String() + labels := append(util.ParseTags(r.UserAgent()), metrics.Label{Name: "err", Value: msg}) + metrics.IncrCounterWithLabels(enclaveErrorCounterName, 1, labels) + + logger.Infow("error processing client request", "err", msg) + + switch enclaveErr { + case pb.Error_DB2_ClientDataSize, + pb.Error_DB2_ClientPinSize, + pb.Error_DB2_ClientTriesTooHigh, + pb.Error_DB2_ClientTriesZero: + // an issue with the client request outside the client error namespace + return websocket.FormatCloseMessage(WSBadArgs, msg) + + case pb.Error_Client_EncryptSerialize, + pb.Error_Client_TransactionInvalid, + pb.Error_Client_CopyDHState, + pb.Error_Client_AlreadyClosed: + // not the client's fault but in the "client" error namespace + return websocket.FormatCloseMessage(WSInternalError, msg) + + case pb.Error_Client_TransactionCancelled, + pb.Error_Core_LeaderUnknown: + // Transient non-serious errors that should just be retried. 
Guaranteed not to have modified any db state + return websocket.FormatCloseMessage(WSUnavailable, msg) + + default: + if pb.Error_Client_NS.Number() < enclaveErr.Number() && enclaveErr < pb.Error_Client_NS+100 { + // unknown error in the client namespace + return websocket.FormatCloseMessage(WSBadArgs, msg) + } + return websocket.FormatCloseMessage(WSInternalError, msg) + } +} + +func (h *websocketHandler) handleClientWebsocket(c *websocket.Conn, authID []byte) error { + response, err := h.enclaveRequester.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_NewClient{NewClient: &pb.NewClientRequest{ + ClientAuthenticatedId: authID, + }}, + }) + if err != nil { + return err + } + + handshake, ok := response.Inner.(*pb.HostToEnclaveResponse_NewClientReply) + if !ok { + return errors.New("unexpected enclave proto") + } + clientID := handshake.NewClientReply.ClientId + // Defer cleanup of the client within the enclave once we're done with it. + // We fire-and-forget this, since it should always succeed on the enclave-side. + defer h.enclaveRequester.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_CloseClient{CloseClient: &pb.CloseClientRequest{ + ClientId: clientID, + }}, + }) + bs, err := proto.Marshal(handshake.NewClientReply.HandshakeStart) + if err != nil { + return errors.New("failed to marshal handshake") + } + + // send ereport to client + if err := h.writeMessage(c, websocket.BinaryMessage, bs); err != nil { + return fmt.Errorf("wswrite: %v", err) + } + + // successfully returned an ereport. now just shunt + // opaque messages bytes between the client and enclave + for { + bs, err := h.readMessage(c) + if err != nil { + return fmt.Errorf("wsread: %w", err) + } + if bs == nil { + // websocket close? 
+ return nil + } + reply, err := h.enclaveRequester.SendTransaction(&pb.HostToEnclaveRequest{ + Inner: &pb.HostToEnclaveRequest_ExistingClient{ExistingClient: &pb.ExistingClientRequest{ + ClientId: clientID, + Data: bs, + }}, + }) + if err != nil { + return err + } + if err = responseErr(reply); err != nil { + return err + } + payload, ok := reply.Inner.(*pb.HostToEnclaveResponse_ExistingClientReply) + if !ok { + return errors.New("unexpected enclave proto") + } + if len(payload.ExistingClientReply.Data) > 0 { + if err := h.writeMessage(c, websocket.BinaryMessage, payload.ExistingClientReply.Data); err != nil { + return fmt.Errorf("wswrite: %v", err) + } + } + if payload.ExistingClientReply.Fin { + return nil + } + } +} + +func (h *websocketHandler) readMessage(c *websocket.Conn) ([]byte, error) { + c.SetReadDeadline(h.clock.Now().Add(h.requestConfig.SocketTimeout)) + _, bs, err := c.ReadMessage() + return bs, err +} + +func (h *websocketHandler) writeMessage(c *websocket.Conn, messageType int, data []byte) error { + c.SetWriteDeadline(h.clock.Now().Add(h.requestConfig.SocketTimeout)) + return c.WriteMessage(messageType, data) +} diff --git a/host/web/middleware/auth.go b/host/web/middleware/auth.go new file mode 100644 index 0000000..2d18b6e --- /dev/null +++ b/host/web/middleware/auth.go @@ -0,0 +1,30 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package middleware + +import ( + "github.com/signalapp/svr2/auth" + "github.com/signalapp/svr2/logger" + "net/http" +) + +// AuthCheck wraps an http.Handler to check the request's BasicAuth using the provided authenticator +func AuthCheck(authenticator auth.Auth, inner http.Handler) http.Handler { + return &authHandler{authenticator, inner} +} + +type authHandler struct { + authenticator auth.Auth + inner http.Handler +} + +func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + user, pass, _ := r.BasicAuth() + if err := a.authenticator.Check(user, pass); 
err != nil { + logger.Warnw("basic auth failed", "err", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + a.inner.ServeHTTP(w, r) +} diff --git a/host/web/middleware/metrics.go b/host/web/middleware/metrics.go new file mode 100644 index 0000000..ed8eb06 --- /dev/null +++ b/host/web/middleware/metrics.go @@ -0,0 +1,81 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package middleware + +import ( + "bufio" + "errors" + "net" + "net/http" + "strconv" + + "github.com/armon/go-metrics" + "github.com/signalapp/svr2/util" +) + +// Instrument wraps an http.Handler and updates metrics with the http response +func Instrument(inner http.Handler) http.Handler { + return &handler{inner: inner} +} + +type handler struct { + inner http.Handler +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ww := &writerWrapper{w: w} + h.inner.ServeHTTP(ww, r) + if ww.recorded { + userAgent := r.UserAgent() + labels := util.ParseTags(userAgent) + labels = append(labels, + metrics.Label{Name: "method", Value: r.Method}, + metrics.Label{Name: "endpoint", Value: r.URL.Path}, + metrics.Label{Name: "status", Value: strconv.Itoa(ww.statusCode)}, + ) + metrics.IncrCounterWithLabels([]string{"http", "response"}, 1, labels) + } +} + +// When a response is written, record the status code so it can be instrumented later +type writerWrapper struct { + w http.ResponseWriter + statusCode int + recorded bool +} + +var _ http.ResponseWriter = (*writerWrapper)(nil) +var _ http.Hijacker = (*writerWrapper)(nil) + +func (ww *writerWrapper) Header() http.Header { + return ww.w.Header() +} + +func (ww *writerWrapper) Write(b []byte) (int, error) { + if !ww.recorded { + ww.recorded = true + ww.statusCode = http.StatusOK + } + return ww.w.Write(b) +} + +func (ww *writerWrapper) WriteHeader(statusCode int) { + ww.recorded = true + ww.statusCode = statusCode + ww.w.WriteHeader(statusCode) +} + +func (ww *writerWrapper) Hijack() (net.Conn, 
*bufio.ReadWriter, error) { + h, ok := ww.w.(http.Hijacker) + if !ok { + return nil, nil, errors.New("hijack not supported") + } + if !ww.recorded { + // If the response handler is switching protocols, (e.x. upgrading + // to a websocket) report StatusSwitchingProtocols + ww.recorded = true + ww.statusCode = http.StatusSwitchingProtocols + } + return h.Hijack() +} diff --git a/host/web/middleware/rate.go b/host/web/middleware/rate.go new file mode 100644 index 0000000..5d3bf6e --- /dev/null +++ b/host/web/middleware/rate.go @@ -0,0 +1,53 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package middleware + +import ( + "errors" + "net/http" + "strconv" + + "github.com/armon/go-metrics" + "github.com/signalapp/svr2/logger" + "github.com/signalapp/svr2/rate" +) + +// RateLimit wraps a http.Handler and enforces a rate limit on requests going to that handler +func RateLimit(limiter rate.Limiter, next http.Handler) http.Handler { + return &rateLimitHandler{limiter, next} +} + +type rateLimitHandler struct { + limiter rate.Limiter + inner http.Handler +} + +var ( + rateLimitCounter = []string{"request", "rateLimit"} + rateLimitErrCounter = []string{"request", "rateLimitErr"} +) + +func (rh *rateLimitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + user, _, _ := r.BasicAuth() + + err := rh.limiter.Limit(r.Context(), user) + var retryErr rate.ErrLimitExceeded + rateLimitExceeded := errors.As(err, &retryErr) + + metrics.IncrCounterWithLabels(rateLimitCounter, 1, []metrics.Label{ + {Name: "exceeded", Value: strconv.FormatBool(rateLimitExceeded)}, + }) + + if rateLimitExceeded { + retryAfterSecs := int64(retryErr.RetryAfter.Seconds()) + w.Header().Set("Retry-After", strconv.FormatInt(retryAfterSecs, 10)) + w.WriteHeader(http.StatusTooManyRequests) + return + } else if err != nil { + // still allow request in the case where we can't access the rate limiter + metrics.IncrCounter(rateLimitErrCounter, 1) + logger.Errorw("could 
not update rate limit", "err", err) + } + rh.inner.ServeHTTP(w, r) +} diff --git a/host/web/server_test.go b/host/web/server_test.go new file mode 100644 index 0000000..e195970 --- /dev/null +++ b/host/web/server_test.go @@ -0,0 +1,92 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +package web + +import ( + "context" + "encoding/base64" + "fmt" + "log" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/signalapp/svr2/config" + "github.com/signalapp/svr2/dispatch" + "github.com/signalapp/svr2/util" + "github.com/signalapp/svr2/web/handlers" + "google.golang.org/protobuf/proto" + + pb "github.com/signalapp/svr2/proto" +) + +type mockPeerSender struct{} + +func (*mockPeerSender) Send(*pb.PeerMessage) error { return nil } + +type mockEnclave struct { + ech chan *pb.EnclaveMessage + uch chan *pb.UntrustedMessage +} + +func (m *mockEnclave) SendMessage(p *pb.UntrustedMessage) error { + m.uch <- p + return nil +} + +func TestServerMockEnclave(t *testing.T) { + m := mockEnclave{make(chan *pb.EnclaveMessage), make(chan *pb.UntrustedMessage)} + txGen := &util.TxGenerator{} + dispatcher := dispatch.New(config.RaftHostConfig{ + RefreshAttestationDuration: time.Minute, + TickDuration: time.Minute, + MetricPollDuration: time.Minute, + EnclaveConcurrency: 3, + }, txGen, &m, m.ech) + go dispatcher.Run(context.Background(), &mockPeerSender{}) + + mux := http.NewServeMux() + mux.Handle("/v1/enclave", handlers.NewWebsocket(&config.Default().Request, dispatcher)) + + ts := httptest.NewServer(mux) + defer ts.Close() + + u := url.URL{Scheme: "ws", Host: ts.Listener.Addr().String(), Path: "v1/enclave"} + log.Println(u.String()) + hdrs := http.Header{} + hdrs.Add("Authorization", fmt.Sprintf("Basic %s", base64.URLEncoding.EncodeToString([]byte("00112233445566778899aabbccddeeff:foo")))) + c, resp, err := websocket.DefaultDialer.Dial(u.String(), hdrs) + if err != nil { + 
t.Fatalf("dial: %v %v", err, resp.StatusCode) + } + + req := <-m.uch + if req1, ok := req.Inner.(*pb.UntrustedMessage_H2ERequest); !ok { + t.Errorf("not HostToEnclaveRequest: %v", req) + } else if _, ok := req1.H2ERequest.Inner.(*pb.HostToEnclaveRequest_NewClient); !ok { + t.Errorf("not NewClient: %v", req) + } + + // send a handshake response + m.ech <- &pb.EnclaveMessage{Inner: &pb.EnclaveMessage_H2EResponse{ + H2EResponse: &pb.HostToEnclaveResponse{ + Inner: &pb.HostToEnclaveResponse_NewClientReply{NewClientReply: &pb.NewClientReply{ + ClientId: 123, + HandshakeStart: &pb.ClientHandshakeStart{ + Evidence: []byte{1}, + Endorsement: []byte{2}, + }, + }}, + RequestId: req.GetH2ERequest().RequestId, + }, + }} + _, msg, _ := c.ReadMessage() + var start pb.ClientHandshakeStart + if err := proto.Unmarshal(msg, &start); err != nil { + t.Error("unmarshal: ", err) + } +} diff --git a/shared/.gitignore b/shared/.gitignore new file mode 100644 index 0000000..63d5cab --- /dev/null +++ b/shared/.gitignore @@ -0,0 +1,4 @@ +*.zip +protoc* +!sha256.* +go* diff --git a/shared/proto/client.proto b/shared/proto/client.proto new file mode 100644 index 0000000..5e8d078 --- /dev/null +++ b/shared/proto/client.proto @@ -0,0 +1,114 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2.client; +option go_package = "github.com/signalapp/svr2/proto"; +option optimize_for = LITE_RUNTIME; + +// Client protocol for SecureValueRecovery2 (SVR2). +// +// A client that wishes to store a secret should call Backup, +// followed by Expose. Backup adds the secret to the store, +// and Expose makes that secret available for restoration. Subsequent +// calls to Restore allow the client to retrieve the stored secret. +// Calls to Backup and Expose are idempotent. A client should call +// Backup until it succeeds, then call Expose until it succeeds. 
+// A client should NOT call Backup+Expose until the pair succeeds, +// as doing so potentially resets the `tries` counter and allows +// for brute force attempts against the pin. If Backup succeeds but +// a subsequent Expose call returns status=ERROR, the client can be +// pretty sure that something nefarious is going on, and they should +// reconsider storing their secret. + +message Request { + oneof inner { + BackupRequest backup = 2; + ExposeRequest expose = 5; + RestoreRequest restore = 3; + DeleteRequest delete = 4; + } +} + +message Response { + oneof inner { + BackupResponse backup = 1; + ExposeResponse expose = 4; + RestoreResponse restore = 2; + DeleteResponse delete = 3; + } +} + +// +// backup +// + +message BackupRequest { + bytes data = 1; // between 16 and 48 bytes + bytes pin = 2; // 32 bytes + uint32 max_tries = 3; // [1,255] +} + +message BackupResponse { + enum Status { + UNSET = 0; // never returned + OK = 1; // successfully set db[backup_id]=data + } + + Status status = 1; +} + +// +// restore +// + +message RestoreRequest { + bytes pin = 1; // 32 bytes +} + +message RestoreResponse { + enum Status { + UNSET = 0; // never returned + OK = 1; // successfully restored, [data] will be set + MISSING = 2; // db[backup_id] does not exist + PIN_MISMATCH = 3; // pin did not match, tries were decremented + } + + Status status = 1; + bytes data = 2; // between 16 and 48 bytes, if set + uint32 tries = 3; // in range [0, 255] +} + +// +// delete +// + +message DeleteRequest { +} + +message DeleteResponse { +} + +// +// expose +// + +message ExposeRequest { + bytes data = 1; +} + +message ExposeResponse { + enum Status { + UNSET = 0; // never returned + OK = 1; // successfully exposed, the backup can now be restored + + // If this status comes back after a successful Backup() call, + // this should be cause for concern. + // It means that someone has either reset, deleted, or tried to brute-force + // the backup since it was created. 
+ ERROR = 2; + } + + Status status = 1; +} diff --git a/shared/proto/client3.proto b/shared/proto/client3.proto new file mode 100644 index 0000000..232e793 --- /dev/null +++ b/shared/proto/client3.proto @@ -0,0 +1,77 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2.client; +option go_package = "github.com/signalapp/svr2/proto"; +option optimize_for = LITE_RUNTIME; + +// Client protocol for SVR3. + +message Request3 { + oneof inner { + CreateRequest create = 2; + EvaluateRequest evaluate = 3; + RemoveRequest remove = 4; + } +} + +message Response3 { + oneof inner { + CreateResponse create = 1; + EvaluateResponse evaluate = 2; + RemoveResponse remove = 4; + } +} + +// +// create +// + +message CreateRequest { + uint32 max_tries = 1; + bytes blinded_element = 2; // ristretto255 element, 32 bytes +} + +message CreateResponse { + enum Status { + UNSET = 0; + OK = 1; + INVALID_REQUEST = 2; + ERROR = 3; + } + Status status = 1; + bytes evaluated_element = 2; // ristretto255 element, 32 bytes + bytes public_key = 3; // ristretto255 key, 32 bytes +} + +// +// evaluate +// + +message EvaluateRequest { + bytes blinded_element = 1; // ristretto255 element, 32 bytes +} + +message EvaluateResponse { + enum Status { + UNSET = 0; + OK = 1; + MISSING = 2; + INVALID_REQUEST = 3; + ERROR = 4; + } + Status status = 1; + bytes evaluated_element = 2; // ristretto255 element, 32 bytes + uint32 tries_remaining = 3; +} + +// +// remove +// + +message RemoveRequest { +} +message RemoveResponse { +} diff --git a/shared/proto/enclaveconfig.proto b/shared/proto/enclaveconfig.proto new file mode 100644 index 0000000..b762520 --- /dev/null +++ b/shared/proto/enclaveconfig.proto @@ -0,0 +1,92 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2.enclaveconfig; +option go_package = "github.com/signalapp/svr2/proto"; +option optimize_for = LITE_RUNTIME; 
+ +// Should match Open Enclave's oe_log_level_t +// (https://github.com/openenclave/openenclave/blob/master/include/openenclave/log.h) +enum EnclaveLogLevel { + LOG_LEVEL_NONE = 0; + LOG_LEVEL_FATAL = 1; + LOG_LEVEL_ERROR = 2; + LOG_LEVEL_WARNING = 3; + LOG_LEVEL_INFO = 4; + LOG_LEVEL_DEBUG = 5; + LOG_LEVEL_VERBOSE = 6; + LOG_LEVEL_MAX = 7; +} + +enum DatabaseVersion { + DATABASE_VERSION_UNKNOWN = 0; + DATABASE_VERSION_SVR2 = 2; + DATABASE_VERSION_SVR3 = 3; +} + +message RaftConfig { + uint32 election_ticks = 1; + uint32 heartbeat_ticks = 2; + uint32 replication_chunk_bytes = 3; + uint32 replica_voting_timeout_ticks = 4; + uint32 replica_membership_timeout_ticks = 5; + uint64 log_max_bytes = 6; + uint32 replication_pipeline = 7; +} + +message EnclaveConfig { + RaftConfig raft = 1; + // Enclave-to-enclave transactions will time out after this many ticks. + uint32 e2e_txn_timeout_ticks = 2; + // Every N ticks, send our local timestamp to our peers. + uint32 send_timestamp_ticks = 5; +} + +// RaftGroupConfig is a configuration shared by members of a Raft group. +// It's created only once, on creation of the Raft group. From that +// point forward, it's shared between replicas as they're added to the +// group, and it's not possible to modify it externally. +message RaftGroupConfig { + // When creating a new group, don't fill this in; it'll be randomly generated. + // This will be passed to other replicas as they join. 
+ fixed64 group_id = 1; + // This raft group will refuse to serve client request with + // tags = 2; + uint64 v = 3; +} + +message MetricsPB { + repeated U64PB counters = 1; + repeated U64PB gauges = 2; +} diff --git a/shared/proto/msgs.proto b/shared/proto/msgs.proto new file mode 100644 index 0000000..9e16091 --- /dev/null +++ b/shared/proto/msgs.proto @@ -0,0 +1,210 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2; +option go_package = "github.com/signalapp/svr2/proto"; +option optimize_for = LITE_RUNTIME; + +import "enclaveconfig.proto"; +import "error.proto"; +import "metrics.proto"; + +// +// shared types +// + +// +// UntrustedMessages are messages from the host to the enclave +// +message UntrustedMessage { + oneof inner { + PeerMessage peer_message = 1; + TimerTick timer_tick = 2; + EnclavePeer reset_peer = 3; + HostToEnclaveRequest h2e_request = 4; + } +} + +message PeerMessage { + bytes peer_id = 1; + oneof inner { + bytes syn = 2; + bytes synack = 3; + bytes data = 4; + bool rst = 5; + } +} + +message TimerTick { + fixed64 new_timestamp_unix_secs = 1; +} + +message EnclavePeer { + bytes peer_id = 1; +} + +message HashResponse { + bytes db_hash = 1; + int64 commit_idx = 2; + bytes commit_hash_chain = 3; +} + +message HostToEnclaveRequest { + uint64 request_id = 1; + oneof inner { + enclaveconfig.EnclaveConfig reconfigure = 2; + bool get_enclave_status = 3; + NewClientRequest new_client = 4; + ExistingClientRequest existing_client = 5; + CloseClientRequest close_client = 6; + DatabaseRequest database_request = 7; + bool create_new_raft_group = 8; + JoinRaftRequest join_raft = 9; + EnclavePeer ping_peer = 10; + bool request_voting = 11; + bool request_metrics = 12; + RefreshAttestation refresh_attestation = 13; + enclaveconfig.EnclaveLogLevel set_log_level = 14; + bool relinquish_leadership = 15; + bool request_removal = 16; + bool hashes = 17; + } +} + +message 
HostToEnclaveResponse { + uint64 request_id = 1; // corresponds to a HostToEnclaveRequest request_id + oneof inner { + error.Error status = 2; // status returnable from any transaction. OK==success, otherwise failure + EnclaveReplicaStatus get_enclave_status_reply = 3; + NewClientReply new_client_reply = 4; + ExistingClientReply existing_client_reply = 5; + metrics.MetricsPB metrics_reply = 6; + HashResponse hashes = 7; + } +} + +// +// Enclave status top-level messages +// +message EnclaveReplicaStatus { + repeated EnclavePeerStatus peers = 1; + RaftState raft_state = 2; +} + +// +// Client messages +// +message NewClientRequest { + bytes client_authenticated_id = 1; +} + +message ClientHandshakeStart { + // Public key associated with this server's enclave. For use in test-only + // contexts where attestation is not available + bytes test_only_pubkey = 1; + + // Remote-attestation evidence associated with the public key + bytes evidence = 2; + + // Endorsements of remote-attestation evidence. + bytes endorsement = 3; +} + +message NewClientReply { + uint64 client_id = 1; + ClientHandshakeStart handshake_start = 2; +} + +message ExistingClientRequest { + uint64 client_id = 1; + bytes data = 2; +} +message ExistingClientReply { + bytes data = 1; + bool fin = 2; // if true, close client (success) after sending data. +} + +message CloseClientRequest { + uint64 client_id = 1; +} + +// +// HostToEnclaveRequest control messages +// +message DatabaseRequest { + bytes authenticated_id = 1; + bytes request = 2; +} + +message JoinRaftRequest { + bytes peer_id = 1; +} + +message RefreshAttestation { + bool rotate_client_key = 1; +} + +// +// enclave messages - messages from the enclave to the host in response to an +// incoming UntrustedMessage. If inner is a PeerMessage it will be forwarded. If it +// is a HostToEnclaveResponse it will be processed by the host. 
+// +message EnclaveMessage { + oneof inner { + PeerMessage peer_message = 1; + HostToEnclaveResponse h2e_response = 2; + } +} + +enum PeerState { + PEER_DISCONNECTED = 0; + PEER_CONNECTING = 1; + PEER_CONNECTED = 2; +} + +// +// enclave status submessages +// +enum RaftState { + // In NO_STATE, there is no raft_, log_, or db_. We're not sure whether + // we're going to create a new raft group, or join an existing one. + RAFTSTATE_NO_STATE = 0; + // In WAITING_FOR_FIRST_CONNECTION, we're waiting for our first peer connection + // to a Raft replica. + RAFTSTATE_WAITING_FOR_FIRST_CONNECTION = 1; + // In LOADING, we're replicating logs and db. + RAFTSTATE_LOADING = 2; + // In LOADED_REQUESTING_MEMBERSHIP, we have received a full state + // from an existing replica, and we're now trying to join the group. + // We're ready to process incoming RaftMessage messages, and we're + // watching our log to see when we become a full member (by watching + // for a ReplicaGroup in the log with our peer ID in it) + RAFTSTATE_LOADED_REQUESTING_MEMBERSHIP = 3; + // In LOADED_PART_OF_GROUP, we're now a full-fledged member of a Raft group. + // We may or may not have voting rights. 
+ RAFTSTATE_LOADED_PART_OF_GROUP = 4; +} + +message ConnectionStatus { + PeerState state = 1; + uint64 last_attestation_unix_secs = 2; +} + +message EnclavePeerStatus { + bytes peer_id = 1; + EnclavePeerReplicationStatus replication_status = 2; + bool in_raft = 3; + bool is_leader = 4; + bool is_voting = 5; + bool me = 6; + ConnectionStatus connection_status = 7; +} + +message EnclavePeerReplicationStatus { + uint64 next_index = 1; + uint64 match_index = 2; + uint64 inflight_index = 3; + bool probing = 4; +} diff --git a/shared/proto/nitro.proto b/shared/proto/nitro.proto new file mode 100644 index 0000000..dcaea5f --- /dev/null +++ b/shared/proto/nitro.proto @@ -0,0 +1,38 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +syntax = "proto3"; + +package svr2.nitro; +option go_package = "github.com/signalapp/svr2/proto"; +option optimize_for = LITE_RUNTIME; +import "error.proto"; +import "enclaveconfig.proto"; + +message InboundMessage { + oneof inner { + enclaveconfig.InitConfig init = 1; + MsgCallRequest msg = 2; + } +} +message OutboundMessage { + oneof inner { + InitCallResponse init = 1; + MsgCallResponse msg = 2; + bytes out = 3; + } +} + +message InitCallResponse { + // there's no `status` here, because a failure to init will crash. + bytes peer_id = 1; +} + +message MsgCallRequest { + uint64 id = 1; + bytes data = 2; // A serialized UntrustedMessage +} +message MsgCallResponse { + uint64 id = 1; + error.Error status = 2; +} diff --git a/shared/svr2.edl b/shared/svr2.edl new file mode 100644 index 0000000..986734c --- /dev/null +++ b/shared/svr2.edl @@ -0,0 +1,47 @@ +enclave { + from "openenclave/edl/syscall.edl" import *; + from "openenclave/edl/logging.edl" import *; + from "platform.edl" import *; + + trusted { + // svr2_init initiates the enclave. + // + // Args: + // enclave_id: Unique identifier for this enclave (will be passed out + // in svr2_output_message calls to differentiate calls from multiple + // enclaves. 
+ // config{,_size}: Serialized EnclaveConfig protobuf. + // peer_id{,32}: The peer_id is the public key of an enclave-generated + // key pair and is generated internally + // Returns: error::Error as int. + public int svr2_init( + size_t config_size, + [in, size=config_size] unsigned char* config, + [out, size=32] unsigned char* peer_id); + + // svr2_input_message sends a message from host->enclave. + // Should not be called concurrently. The enclave won't care, but + // the caller won't know which svr2_output_message is associated with + // which input message, as enclave-side locking will be opaque. + // + // Args: + // msg{,_size}: Serialized UntrustedMessage. + // Returns: error::Error as int. + public int svr2_input_message( + size_t msg_size, + [in, size=msg_size] unsigned char* msg); + }; + + untrusted { + // svr2_output_message sends a message from enclave->host. + // It will only be called during a call to svr2_input_message. + // + // Args: + // msg{,_size}: Serialized EnclaveMessage. + void svr2_output_message( + size_t msg_size, + [in, size=msg_size] unsigned char* msg); + }; +}; + +