Signal-iOS/SignalServiceKit/SecureValueRecovery/SgxWebsocketConnection.swift
2026-06-05 08:15:59 -05:00

179 lines
6.5 KiB
Swift

//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import LibSignalClient
import SwiftProtobuf
/// Exposes an `SgxClient`-conformant server communication channel.
///
/// This handles the initial handshake & subsequent encryption/decryption
/// of the exchanged messages using an `SgxClient` instance provided
/// by a `SgxWebsocketConfigurator`.
///
/// While this is a class, there should never be an instance of this base class; all instances
/// should be of a concrete subclass. It is only a class and not a protocol so users can refer
/// to an instance by config type without specifying the implementation,
/// e.g. `SgxWebsocketConnection<FooServerConfigurator>`.
/// That is not possible for a protocol with an associated type.
public class SgxWebsocketConnection<Configurator: SgxWebsocketConfigurator> {

    // MARK: - Init

    /// Deliberately `fileprivate`: only subclasses declared in this file can
    /// chain up to it. Do not add a more visible initializer; this base class
    /// is not meant to be instantiated on its own.
    fileprivate init() {}

    // MARK: - Abstract properties

    /// Abstract. The base implementation traps; concrete subclasses override it.
    public var mrEnclave: MrEnclave {
        fatalError("Concrete subclass must implement")
    }

    /// Abstract. The base implementation traps; concrete subclasses override it.
    public var client: Configurator.Client {
        fatalError("Concrete subclass must implement")
    }

    /// Abstract. The base implementation traps; concrete subclasses override it.
    public var auth: RemoteAttestationAuth {
        fatalError("Concrete subclass must implement")
    }

    // MARK: - Abstract methods

    /// Sends `request` and returns the matching response.
    ///
    /// Abstract: the base implementation traps, so subclasses must override.
    func sendRequestAndReadResponse(_ request: Configurator.Request) async throws -> Configurator.Response {
        fatalError("Concrete subclass must implement")
    }

    /// Closes the connection, optionally with a websocket close `code`.
    ///
    /// Abstract: the base implementation traps, so subclasses must override.
    func disconnect(code: URLSessionWebSocketTask.CloseCode?) {
        fatalError("Concrete subclass must implement")
    }
}
/// Concrete `SgxWebsocketConnection` backed by a real web socket.
///
/// Instances are only created through `connectAndPerformHandshake`, so every
/// instance that exists has already completed the handshake with its client.
public class SgxWebsocketConnectionImpl<Configurator: SgxWebsocketConfigurator>: SgxWebsocketConnection<Configurator> {

    // MARK: - Stored state

    /// Socket used for every frame sent or received by this connection.
    private let webSocket: WebSocketPromise
    /// Source of the enclave identifier reported by `mrEnclave`.
    private let configurator: Configurator
    /// Backing storage for the `client` override.
    private let _client: Configurator.Client
    /// Backing storage for the `auth` override.
    private let _auth: RemoteAttestationAuth

    // MARK: - Init

    /// Private so that the only way to obtain an instance is via
    /// `connectAndPerformHandshake`.
    private init(
        webSocket: WebSocketPromise,
        configurator: Configurator,
        client: Configurator.Client,
        auth: RemoteAttestationAuth,
    ) {
        self._auth = auth
        self._client = client
        self.configurator = configurator
        self.webSocket = webSocket
        super.init()
    }

    // MARK: - Connecting

    /// Opens a socket for `configurator`, runs the handshake over it, and
    /// returns the resulting connection.
    ///
    /// - Parameters:
    ///   - configurator: Supplies the mrenclave and builds the client.
    ///   - auth: Credentials placed in the socket request's auth header.
    ///   - websocketFactory: Factory used to create the underlying socket.
    /// - Returns: A connection whose client has completed its handshake.
    /// - Throws: Any error from building the socket or from the handshake.
    ///   A handshake error also disconnects the socket before being rethrown.
    static func connectAndPerformHandshake(
        configurator: Configurator,
        auth: RemoteAttestationAuth,
        websocketFactory: WebSocketFactory,
    ) async throws -> SgxWebsocketConnection<Configurator> {
        // A failure here is thrown as-is: there is no socket to tear down yet.
        let socket = try buildSocket(
            configurator: configurator,
            auth: auth,
            websocketFactory: websocketFactory,
        )

        do {
            // The first message read from the socket is used as the
            // attestation message from which the client is created.
            let attestation = try await socket.waitForResponse().awaitable()
            let sgxClient = try Configurator.client(
                mrenclave: configurator.mrenclave,
                attestationMessage: attestation,
                currentDate: Date(),
            )

            // Send the client's initial request, then feed the next message
            // back into the client to finish the handshake.
            socket.send(data: sgxClient.initialRequest())
            let handshakeReply = try await socket.waitForResponse().awaitable()
            try sgxClient.completeHandshake(handshakeReply)

            let connection = SgxWebsocketConnectionImpl<Configurator>(
                webSocket: socket,
                configurator: configurator,
                client: sgxClient,
                auth: auth,
            )
            return connection
        } catch {
            // Any handshake failure closes the socket before rethrowing.
            Logger.warn("\(type(of: configurator).loggingName): Disconnecting socket after failed handshake: \(error)")
            socket.disconnect(code: .invalidFramePayloadData)
            throw error
        }
    }

    /// Builds the socket request (URL path derived from the mrenclave hex
    /// string, plus an auth header) and asks `websocketFactory` for a socket.
    ///
    /// - Throws: `OWSAssertionError` if the factory returns `nil`.
    private static func buildSocket(
        configurator: Configurator,
        auth: RemoteAttestationAuth,
        websocketFactory: WebSocketFactory,
    ) throws -> WebSocketPromise {
        let authorization = HttpHeaders.authHeaderValue(username: auth.username, password: auth.password)
        let enclaveHex = configurator.mrenclave.dataValue.hexadecimalString
        let socketRequest = WebSocketRequest(
            signalService: Configurator.signalServiceType,
            urlPath: Configurator.websocketUrlPath(mrenclaveString: enclaveHex),
            urlQueryItems: nil,
            extraHeaders: [HttpHeaders.authHeaderKey: authorization],
        )

        let pendingSocket = websocketFactory.webSocketPromise(
            request: socketRequest,
            callbackScheduler: DispatchQueue.global(),
        )
        if let pendingSocket {
            return pendingSocket
        }
        throw OWSAssertionError("We should always be able to get a web socket from this API.")
    }

    // MARK: - SgxWebsocketConnection overrides

    override public var mrEnclave: MrEnclave { configurator.mrenclave }

    override public var client: Configurator.Client { _client }

    override public var auth: RemoteAttestationAuth { _auth }

    /// Serializes and encrypts `request`, sends it, then waits for the next
    /// message on the socket and decrypts and parses it as the response.
    ///
    /// - Throws: Errors from serialization, encryption, waiting on the socket,
    ///   decryption, or parsing the response.
    override public func sendRequestAndReadResponse(
        _ request: Configurator.Request,
    ) async throws -> Configurator.Response {
        let serializedRequest = try request.serializedData()
        try encryptAndSendRequest(serializedRequest)

        let encryptedReply = try await webSocket.waitForResponse().awaitable()
        let serializedReply = try decryptResponse(encryptedReply)
        return try Configurator.Response(serializedBytes: serializedReply)
    }

    /// Encrypts `request` with the client and sends the result on the socket.
    private func encryptAndSendRequest(_ request: Data) throws {
        let ciphertext = try client.establishedSend(request)
        webSocket.send(data: ciphertext)
    }

    /// Decrypts a message received from the socket using the client.
    private func decryptResponse(_ encryptedResponse: Data) throws -> Data {
        try client.establishedRecv(encryptedResponse)
    }

    /// Forwards the disconnect, and the optional close `code`, to the socket.
    override public func disconnect(code: URLSessionWebSocketTask.CloseCode?) {
        webSocket.disconnect(code: code)
    }
}
#if TESTABLE_BUILD
/// Test double for `SgxWebsocketConnection`. Every overridden member is backed
/// by a settable property or closure, so a test can script the connection.
public class MockSgxWebsocketConnection<Configurator: SgxWebsocketConfigurator>: SgxWebsocketConnection<Configurator> {

    /// Widens the base class's `fileprivate` initializer so the mock can be
    /// constructed from outside this file.
    override init() {
        super.init()
    }

    // MARK: - Stubbed properties

    /// Value returned by `mrEnclave`. Implicitly unwrapped: set before reading.
    public var mockEnclave: MrEnclave!
    override public var mrEnclave: MrEnclave { mockEnclave }

    /// Value returned by `client`. Implicitly unwrapped: set before reading.
    public var mockClient: Configurator.Client!
    override public var client: Configurator.Client { mockClient }

    /// Value returned by `auth`. Implicitly unwrapped: set before reading.
    public var mockAuth: RemoteAttestationAuth!
    override public var auth: RemoteAttestationAuth { mockAuth }

    // MARK: - Stubbed methods

    /// Handler that produces the response for each request.
    public var onSendRequestAndReadResponse: ((Configurator.Request) -> Promise<Configurator.Response>)?

    /// Passes `request` to `onSendRequestAndReadResponse` and awaits the
    /// promise it returns.
    override public func sendRequestAndReadResponse(
        _ request: Configurator.Request,
    ) async throws -> Configurator.Response {
        // Force-unwrapped on purpose: this traps if no handler has been set.
        let pendingResponse = onSendRequestAndReadResponse!(request)
        return try await pendingResponse.awaitable()
    }

    /// Optional hook run on `disconnect(code:)`.
    public var onDisconnect: (() -> Void)?

    /// Runs `onDisconnect` if one is set; `code` is ignored.
    override public func disconnect(code: URLSessionWebSocketTask.CloseCode?) {
        if let onDisconnect {
            onDisconnect()
        }
    }
}
#endif