SecureValueRecovery2/enclave/socketmain/socketmain.cc

240 lines
7.9 KiB
C++

// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <linux/vm_sockets.h>
#include <pthread.h>
#include <errno.h>
#include <unistd.h>
#include <cstdlib>
#include "env/env.h"
#include "core/core.h"
#include "context/context.h"
#include "proto/enclaveconfig.pb.h"
#include "util/log.h"
#include "util/bytes.h"
#include "proto/socketmain.pb.h"
#include "socketwrap/socket.h"
#include "env/socket/socket.h"
#include "queue/queue.h"
namespace svr2 {
// Evaluates the condition `x`. When it holds, nothing happens. When it is
// false, the stringified condition is logged together with the current errno
// value and its strerror() text, and the enclosing function returns
// COUNTED_ERROR(err). errno is copied into a local before logging so that the
// logging call itself cannot alter the value that gets reported.
#define RETURN_ERRNO_ERROR_UNLESS(x, err) do { \
  if (x) break; \
  const int saved_errno = errno; \
  LOG(ERROR) << "(" << #x << ") evaluated to false, errno(" << saved_errno << "): " << strerror(saved_errno); \
  return COUNTED_ERROR(err); \
} while (0)
// To simplify our server, this function creates a listening socket of the
// requested family (AF_INET or AF_VSOCK), binds it to `port` on all local
// addresses, listens, and accepts a single client connection. It returns the
// accepted file descriptor in `*afd` and closes the listener. We know that if
// this socket dies, we stop serving, so there's no need to create an accept
// loop.
//
// Connections that are opened and then closed without sending any data are
// treated as health checks: they are closed and accepting continues. Any
// other connection must send the hello byte 'N' first, and that byte is
// consumed here. For AF_INET, TCP_NODELAY and TCP keepalive are enabled on the
// accepted socket.
//
// Every descriptor opened here is closed on all error paths. `*afd` is only
// written on success.
error::Error AcceptSocket(int sock_type, int port, int* afd) {
  // Owns a file descriptor and closes it on scope exit unless released, so
  // early error returns cannot leak the listener or an accepted connection.
  class ScopedFd {
   public:
    explicit ScopedFd(int fd = -1) : fd_(fd) {}
    ~ScopedFd() { reset(); }
    ScopedFd(const ScopedFd&) = delete;
    ScopedFd& operator=(const ScopedFd&) = delete;
    int get() const { return fd_; }
    // Closes the currently owned descriptor (if any) and takes ownership of `fd`.
    void reset(int fd = -1) {
      if (fd_ >= 0) ::close(fd_);
      fd_ = fd;
    }
    // Gives up ownership without closing; returns the descriptor.
    int release() {
      int fd = fd_;
      fd_ = -1;
      return fd;
    }
   private:
    int fd_;
  };

  // Build the bind address first so an unsupported family is rejected before
  // any socket is created (and therefore before anything can leak).
  struct sockaddr* addr;
  socklen_t addr_size;
  struct sockaddr_vm vm_addr;
  struct sockaddr_in in_addr;
  switch (sock_type) {
    case AF_INET:
      memset(&in_addr, 0, sizeof(in_addr));
      in_addr.sin_family = AF_INET;
      in_addr.sin_port = htons(port);
      in_addr.sin_addr.s_addr = INADDR_ANY;
      addr = reinterpret_cast<struct sockaddr*>(&in_addr);
      addr_size = sizeof(in_addr);
      break;
    case AF_VSOCK:
      memset(&vm_addr, 0, sizeof(vm_addr));
      vm_addr.svm_family = AF_VSOCK;
      vm_addr.svm_port = port;
      vm_addr.svm_cid = VMADDR_CID_ANY;
      addr = reinterpret_cast<struct sockaddr*>(&vm_addr);
      addr_size = sizeof(vm_addr);
      break;
    default:
      return COUNTED_ERROR(SocketMain_UnsupportedSockType);
  }

  // 0 is a valid descriptor, so only negative return values indicate failure.
  ScopedFd listener(socket(sock_type, SOCK_STREAM, 0));
  RETURN_ERRNO_ERROR_UNLESS(listener.get() >= 0, SocketMain_SocketCreation);
  LOG(INFO) << "Binding to port " << port;
  RETURN_ERRNO_ERROR_UNLESS(
      0 == bind(listener.get(), addr, addr_size),
      SocketMain_SocketBind);
  RETURN_ERRNO_ERROR_UNLESS(
      0 == listen(listener.get(), 10),
      SocketMain_SocketListen);

  ScopedFd conn;
  const socklen_t initial_size = addr_size;
  while (conn.get() < 0) {
    LOG(INFO) << "Accepting...";
    addr_size = initial_size;
    memset(addr, 0, addr_size);
    int accepted = accept4(listener.get(), addr, &addr_size, SOCK_CLOEXEC);
    RETURN_ERRNO_ERROR_UNLESS(
        accepted >= 0 || errno == EINTR || errno == ECONNABORTED,
        SocketMain_SocketAccept);
    if (accepted < 0) {
      // Transient accept failure: retry rather than reading from an invalid fd.
      continue;
    }
    conn.reset(accepted);
    uint8_t hello = 0;
    ssize_t got;
    do {
      got = recv(conn.get(), &hello, 1, 0);
    } while (got < 0 && errno == EINTR);  // retry reads interrupted by signals
    if (got == 0) {
      LOG(INFO) << "Socket opened then closed without any data being sent, assuming a health check";
      conn.reset();  // close the health-check connection and keep accepting
      continue;
    }
    RETURN_ERRNO_ERROR_UNLESS(got == 1, SocketMain_SocketAccept);
    if (hello != 'N') {
      LOG(ERROR) << "Missing socketmain hello byte";
      return COUNTED_ERROR(SocketMain_SocketAccept);
    }
  }
  shutdown(listener.get(), SHUT_RDWR);
  listener.reset();

  if (sock_type == AF_INET) {
    int tcp_nodelay = 1;
    RETURN_ERRNO_ERROR_UNLESS(
        0 == setsockopt(conn.get(), IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(tcp_nodelay)),
        SocketMain_SocketSetOpt);
    int tcp_keepalive = 1;
    int tcp_keepalive_idle = 60;   // send first probe after 1m of inactivity
    int tcp_keepalive_intvl = 30;  // send subsequent probes every 30s
    int tcp_keepalive_cnt = 8;     // fail if 8 probes are unack'd. This totals ~5m of total time
    RETURN_ERRNO_ERROR_UNLESS(
        0 == setsockopt(conn.get(), SOL_SOCKET, SO_KEEPALIVE, &tcp_keepalive, sizeof(tcp_keepalive)),
        SocketMain_SocketSetOpt);
    RETURN_ERRNO_ERROR_UNLESS(
        0 == setsockopt(conn.get(), IPPROTO_TCP, TCP_KEEPIDLE, &tcp_keepalive_idle, sizeof(tcp_keepalive_idle)),
        SocketMain_SocketSetOpt);
    RETURN_ERRNO_ERROR_UNLESS(
        0 == setsockopt(conn.get(), IPPROTO_TCP, TCP_KEEPINTVL, &tcp_keepalive_intvl, sizeof(tcp_keepalive_intvl)),
        SocketMain_SocketSetOpt);
    RETURN_ERRNO_ERROR_UNLESS(
        0 == setsockopt(conn.get(), IPPROTO_TCP, TCP_KEEPCNT, &tcp_keepalive_cnt, sizeof(tcp_keepalive_cnt)),
        SocketMain_SocketSetOpt);
  }
  *afd = conn.release();
  LOG(INFO) << "Sucessfully accepted connection on FD=" << *afd;
  return error::OK;
}
// Serves requests on one worker thread until something fails. Each iteration:
//  1. reads one socketmain::InboundMessage from `sock` (it must carry `msg`),
//  2. parses that message's `data` as an UntrustedMessage,
//  3. hands it to `core->Receive`,
//  4. sends an OutboundMessage that echoes the request id with the status.
// A new context::Context is constructed for every request. The function only
// returns, with an error, when reading, parsing, or sending fails.
error::Error RunServerThread(core::Core* core, socketwrap::Socket* sock) {
  for (;;) {
    context::Context ctx;
    auto request = ctx.Protobuf<socketmain::InboundMessage>();
    RETURN_IF_ERROR(sock->ReadPB(&ctx, request));
    const bool is_msg = request->inner_case() == socketmain::InboundMessage::kMsg;
    if (!is_msg) return COUNTED_ERROR(SocketMain_InboundNotMessage);

    // The `msg` case is set (checked above), so the const accessor sees the
    // same sub-message a mutable accessor would.
    const auto& envelope = request->msg();
    auto untrusted = ctx.Protobuf<UntrustedMessage>();
    if (!untrusted->ParseFromString(envelope.data())) {
      return COUNTED_ERROR(SocketMain_InboundMessageParse);
    }

    auto status = core->Receive(&ctx, *untrusted);
    auto response = ctx.Protobuf<socketmain::OutboundMessage>();
    auto* response_msg = response->mutable_msg();
    response_msg->set_id(envelope.id());
    response_msg->set_status(status);
    RETURN_IF_ERROR(env::socket::SendOutboundMessage(&ctx, *response));
  }
}
// Reads the first message from `sock`, which must be an `init` message. It
// uses that message to optionally set the log level, initialize the
// environment, and create a new core. If the core is created successfully,
// an `init` OutboundMessage carrying the core's ID is sent back.
// Returns the pointer produced by core::Core::Create (nullptr if the init
// message could not be read or was of the wrong kind) paired with the
// resulting error. Note the pointer can be non-null even when sending the
// reply failed.
std::pair<std::unique_ptr<core::Core>, error::Error> InitCore(socketwrap::Socket* sock) {
  context::Context ctx;
  auto inbound = ctx.Protobuf<socketmain::InboundMessage>();
  LOG(INFO) << "Reading init message";
  error::Error read_err = sock->ReadPB(&ctx, inbound);
  if (read_err != error::OK) return std::make_pair(nullptr, read_err);
  if (inbound->inner_case() != socketmain::InboundMessage::kInit) {
    return std::make_pair(nullptr, COUNTED_ERROR(SocketMain_InboundNotInit));
  }

  auto init_msg = inbound->init();
  const auto requested_level = init_msg.initial_log_level();
  // LOG_LEVEL_NONE means "leave the current log level alone".
  if (requested_level != enclaveconfig::LOG_LEVEL_NONE) util::SetLogLevel(requested_level);
  env::Init(init_msg.group_config().simulated());

  LOG(INFO) << "Creating core";
  auto [core_ptr, err] = core::Core::Create(&ctx, init_msg);
  if (err == error::OK) {
    LOG(INFO) << "Writing init message";
    auto reply = ctx.Protobuf<socketmain::OutboundMessage>();
    core_ptr->ID().ToString(reply->mutable_init()->mutable_peer_id());
    err = env::socket::SendOutboundMessage(&ctx, *reply);
  }
  LOG(INFO) << "Core creation: " << err;
  return std::make_pair(std::move(core_ptr), err);
}
// Run a server. It accepts a single connection on `port` using socket family
// `sock_type`, initializes the core from that connection's first message, and
// then serves requests on it with a pool of worker threads. A dedicated
// thread runs env::socket::SendSocketMessages on the same socket.
// The sender and worker threads LOG(FATAL) if their loops ever return. An
// accept failure is returned to the caller. A core-initialization failure is
// reported with LOG(FATAL); see the comment at that point for why.
error::Error RunServer(int port, int sock_type) {
  // Number of request-serving worker threads.
  constexpr size_t kServerThreads = 32;  /* chosen by random dice roll */
  int fd;
  RETURN_IF_ERROR(AcceptSocket(sock_type, port, &fd));
  socketwrap::Socket sock(fd);
  auto sockp = &sock;
  std::vector<std::thread> threads;
  threads.reserve(kServerThreads + 1);
  threads.emplace_back([sockp]{
    LOG(INFO) << "Starting thread to send NSM messages";
    LOG(FATAL) << env::socket::SendSocketMessages(sockp);
  });
  auto [c, err] = InitCore(&sock);
  if (err != error::OK) {
    // The sender thread above is already running and references `sock`.
    // Returning here would destroy a joinable std::thread, which calls
    // std::terminate and loses `err`. It would also leave that thread with a
    // dangling socket. So report the failure before unwinding.
    LOG(FATAL) << "Core initialization failed: " << err;
    return err;  // not expected to be reached
  }
  auto cp = c.get();
  for (size_t i = 0; i < kServerThreads; i++) {
    threads.emplace_back([cp, sockp]{
      LOG(FATAL) << RunServerThread(cp, sockp);
    });
  }
  for (auto& t : threads) {
    t.join();
  }
  return error::OK;  // unreachable
}
} // namespace svr2
// Entry point. Parses --sock_type={af_inet,af_vsock} (required) and
// --port=N (optional, default 27427, must be in 1..65535), then runs the
// server until it dies. Any unrecognized or malformed argument is fatal.
int main(int argc, char** argv) {
  int port = 27427;
  int sock_type = 0;
  const std::string port_prefix = "--port=";
  for (int i = 1; i < argc; i++) {
    std::string arg(argv[i]);
    if (arg.rfind(port_prefix, 0) == 0) {
      // strtol (unlike atoi) reports overflow and trailing garbage, so values
      // such as "--port=99999999999" (undefined behavior with atoi) or
      // "--port=80x" are rejected instead of being misparsed.
      const char* digits = arg.c_str() + port_prefix.size();
      char* end = nullptr;
      errno = 0;
      long parsed = strtol(digits, &end, 10);
      if (end != digits && *end == '\0' && errno == 0 && parsed > 0 && parsed < 65536) {
        port = static_cast<int>(parsed);
        LOG(INFO) << "Running on port " << port;
        continue;
      }
    } else if (arg == "--sock_type=af_inet") {
      LOG(INFO) << "Using socket type 'af_inet'";
      sock_type = AF_INET;
      continue;
    } else if (arg == "--sock_type=af_vsock") {
      LOG(INFO) << "Using socket type 'af_vsock'";
      sock_type = AF_VSOCK;
      continue;
    }
    LOG(FATAL) << "Usage: " << argv[0]
               << " --sock_type={af_inet,af_vsock} [--port=###]";
  }
  if (sock_type == 0) {
    LOG(FATAL) << "socket type not set, use --sock_type=xxx";
  }
  auto err = svr2::RunServer(port, sock_type);
  LOG(FATAL) << err;
  return -1;
}