De-singleton RemoteAttestation

This commit is contained in:
Pete Walters 2026-06-05 08:15:59 -05:00 committed by GitHub
parent f40bc944ae
commit 1954342a36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 183 additions and 181 deletions

View File

@ -3822,7 +3822,7 @@
F9C5CCA3289453B300548EEE /* StorageService.pb.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9B8289453B100548EEE /* StorageService.pb.swift */; };
F9C5CCA4289453B300548EEE /* SSKProto+OWS.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9B9289453B100548EEE /* SSKProto+OWS.swift */; };
F9C5CCAC289453B300548EEE /* PreKeyManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9C2289453B100548EEE /* PreKeyManager.swift */; };
F9C5CCB0289453B300548EEE /* RemoteAttestation.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9C7289453B100548EEE /* RemoteAttestation.swift */; };
F9C5CCB0289453B300548EEE /* RemoteAttestationAuthFetcher.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9C7289453B100548EEE /* RemoteAttestationAuthFetcher.swift */; };
F9C5CCC0289453B300548EEE /* ContactDiscoveryTask.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9D9289453B100548EEE /* ContactDiscoveryTask.swift */; };
F9C5CCC3289453B300548EEE /* ContactDiscoveryError.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9DC289453B100548EEE /* ContactDiscoveryError.swift */; };
F9C5CCC5289453B300548EEE /* SignalAccount.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9DE289453B100548EEE /* SignalAccount.swift */; };
@ -8136,7 +8136,7 @@
F9C5C9B8289453B100548EEE /* StorageService.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = StorageService.pb.swift; sourceTree = "<group>"; };
F9C5C9B9289453B100548EEE /* SSKProto+OWS.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "SSKProto+OWS.swift"; sourceTree = "<group>"; };
F9C5C9C2289453B100548EEE /* PreKeyManager.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PreKeyManager.swift; sourceTree = "<group>"; };
F9C5C9C7289453B100548EEE /* RemoteAttestation.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = RemoteAttestation.swift; sourceTree = "<group>"; };
F9C5C9C7289453B100548EEE /* RemoteAttestationAuthFetcher.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = RemoteAttestationAuthFetcher.swift; sourceTree = "<group>"; };
F9C5C9D9289453B100548EEE /* ContactDiscoveryTask.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ContactDiscoveryTask.swift; sourceTree = "<group>"; };
F9C5C9DC289453B100548EEE /* ContactDiscoveryError.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ContactDiscoveryError.swift; sourceTree = "<group>"; };
F9C5C9DE289453B100548EEE /* SignalAccount.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SignalAccount.swift; sourceTree = "<group>"; };
@ -15367,7 +15367,7 @@
F9D5BFCC2979A017001737E5 /* OWSRequestFactory+Spam.swift */,
D95C39E7296DEBFB00A9DA23 /* OWSRequestFactory+Usernames.swift */,
F9C5CAE2289453B200548EEE /* OWSRequestFactory.swift */,
F9C5C9C7289453B100548EEE /* RemoteAttestation.swift */,
F9C5C9C7289453B100548EEE /* RemoteAttestationAuthFetcher.swift */,
66C2B1302A05D28A008DDE72 /* TSRequest.swift */,
);
path = Requests;
@ -19723,7 +19723,7 @@
6646573B2AC388C70099DE1C /* RegistrationStateChangeManager.swift in Sources */,
6646573D2AC3894D0099DE1C /* RegistrationStateChangeManagerImpl.swift in Sources */,
040506FC2F7FE3DB0078B769 /* RemoteAnnouncementModel.swift in Sources */,
F9C5CCB0289453B300548EEE /* RemoteAttestation.swift in Sources */,
F9C5CCB0289453B300548EEE /* RemoteAttestationAuthFetcher.swift in Sources */,
F9C5CE17289453B400548EEE /* RemoteConfigManager.swift in Sources */,
D98DD86028EE53B00089333E /* RemoteMegaphoneModel.swift in Sources */,
040507012F804C240078B769 /* RemoteReleaseNotesService.swift in Sources */,

View File

@ -3364,9 +3364,9 @@ public class RegistrationCoordinatorTest {
// Put some auth credentials in storage.
let svr2CredentialCandidates: [SVR2AuthCredential] = [
Stubs.svr2AuthCredential,
SVR2AuthCredential(credential: RemoteAttestation.Auth(username: "aaaa", password: "abc")),
SVR2AuthCredential(credential: RemoteAttestation.Auth(username: "zzzz", password: "xyz")),
SVR2AuthCredential(credential: RemoteAttestation.Auth(username: "0000", password: "123")),
SVR2AuthCredential(credential: RemoteAttestationAuth(username: "aaaa", password: "abc")),
SVR2AuthCredential(credential: RemoteAttestationAuth(username: "zzzz", password: "xyz")),
SVR2AuthCredential(credential: RemoteAttestationAuth(username: "0000", password: "123")),
]
svrAuthCredentialStore.svr2Dict = Dictionary(grouping: svr2CredentialCandidates, by: \.credential.username).mapValues { $0.first! }
@ -3398,7 +3398,7 @@ public class RegistrationCoordinatorTest {
static let aci = Aci.randomForTesting()
static let pinCode = "1234"
static let svr2AuthCredential = SVR2AuthCredential(credential: RemoteAttestation.Auth(username: "xxx", password: "yyy"))
static let svr2AuthCredential = SVR2AuthCredential(credential: RemoteAttestationAuth(username: "xxx", password: "yyy"))
static let captchaToken = "captchaToken"
static let apnsToken = "apnsToken"

View File

@ -104,6 +104,7 @@ public final class ContactDiscoveryManagerImpl: ContactDiscoveryManager {
recipientFetcher: RecipientFetcher,
recipientManager: any SignalRecipientManager,
recipientMerger: RecipientMerger,
remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher,
tsAccountManager: TSAccountManager,
udManager: OWSUDManager,
libsignalNet: Net,
@ -115,6 +116,7 @@ public final class ContactDiscoveryManagerImpl: ContactDiscoveryManager {
recipientFetcher: recipientFetcher,
recipientManager: recipientManager,
recipientMerger: recipientMerger,
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
tsAccountManager: tsAccountManager,
udManager: udManager,
libsignalNet: libsignalNet,

View File

@ -17,6 +17,7 @@ final class ContactDiscoveryTaskQueueImpl: ContactDiscoveryTaskQueue {
private let recipientFetcher: RecipientFetcher
private let recipientManager: any SignalRecipientManager
private let recipientMerger: RecipientMerger
private let remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher
private let tsAccountManager: TSAccountManager
private let udManager: OWSUDManager
private let libsignalNet: Net
@ -27,6 +28,7 @@ final class ContactDiscoveryTaskQueueImpl: ContactDiscoveryTaskQueue {
recipientFetcher: RecipientFetcher,
recipientManager: any SignalRecipientManager,
recipientMerger: RecipientMerger,
remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher,
tsAccountManager: TSAccountManager,
udManager: OWSUDManager,
libsignalNet: Net,
@ -36,6 +38,7 @@ final class ContactDiscoveryTaskQueueImpl: ContactDiscoveryTaskQueue {
self.recipientFetcher = recipientFetcher
self.recipientManager = recipientManager
self.recipientMerger = recipientMerger
self.remoteAttestationAuthFetcher = remoteAttestationAuthFetcher
self.tsAccountManager = tsAccountManager
self.udManager = udManager
self.libsignalNet = libsignalNet
@ -53,7 +56,7 @@ final class ContactDiscoveryTaskQueueImpl: ContactDiscoveryTaskQueue {
mode: mode,
udManager: udManager,
connectionImpl: libsignalNet,
remoteAttestation: ContactDiscoveryV2Operation<LibSignalClient.Net>.Wrappers.RemoteAttestation(),
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
).perform()
return try await self.processResults(requestedPhoneNumbers: e164s, discoveryResults: discoveryResults)

View File

@ -101,7 +101,7 @@ final class ContactDiscoveryV2Operation<ConnectionType: ContactDiscoveryConnecti
private let connectionImpl: ConnectionType
private let remoteAttestation: Shims.RemoteAttestation
private let remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher
convenience init(
db: any DB,
@ -109,7 +109,7 @@ final class ContactDiscoveryV2Operation<ConnectionType: ContactDiscoveryConnecti
mode: ContactDiscoveryMode,
udManager: any OWSUDManager,
connectionImpl: ConnectionType,
remoteAttestation: any Shims.RemoteAttestation,
remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher,
) {
self.init(
db: db,
@ -117,7 +117,7 @@ final class ContactDiscoveryV2Operation<ConnectionType: ContactDiscoveryConnecti
persistentState: mode == .oneOffUserRequest ? nil : ContactDiscoveryV2PersistentStateImpl(),
udManager: udManager,
connectionImpl: connectionImpl,
remoteAttestation: remoteAttestation,
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
)
}
@ -127,19 +127,22 @@ final class ContactDiscoveryV2Operation<ConnectionType: ContactDiscoveryConnecti
persistentState: (any ContactDiscoveryV2PersistentState)?,
udManager: any OWSUDManager,
connectionImpl: ConnectionType,
remoteAttestation: any Shims.RemoteAttestation,
remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher,
) {
self.db = db
self.e164sToLookup = e164sToLookup
self.persistentState = persistentState
self.udManager = udManager
self.connectionImpl = connectionImpl
self.remoteAttestation = remoteAttestation
self.remoteAttestationAuthFetcher = remoteAttestationAuthFetcher
}
func perform() async throws -> [ContactDiscoveryResult] {
do {
let cdsiAuth = try await self.remoteAttestation.authForCDSI()
let cdsiAuth = try await self.remoteAttestationAuthFetcher.fetchAuth(
forService: .cdsi,
chatServiceAuth: .implicit(),
)
let request = try self.buildRequest()
let auth = LibSignalClient.Auth(username: cdsiAuth.username, password: cdsiAuth.password)
let tokenResult = try await self.connectionImpl.performRequest(request, auth: auth)
@ -322,27 +325,3 @@ private class ContactDiscoveryV2PersistentStateImpl: ContactDiscoveryV2Persisten
}
}
}
// MARK: - Shims
extension ContactDiscoveryV2Operation {
enum Shims {
typealias RemoteAttestation = _ContactDiscoveryV2Operation_RemoteAttestationShim
}
enum Wrappers {
typealias RemoteAttestation = _ContactDiscoveryV2Operation_RemoteAttestationWrapper
}
}
protocol _ContactDiscoveryV2Operation_RemoteAttestationShim {
func authForCDSI() async throws -> RemoteAttestation.Auth
}
class _ContactDiscoveryV2Operation_RemoteAttestationWrapper: _ContactDiscoveryV2Operation_RemoteAttestationShim {
init() {}
func authForCDSI() async throws -> RemoteAttestation.Auth {
return try await RemoteAttestation.authForCDSI()
}
}

View File

@ -437,12 +437,14 @@ extension AppSetup.GlobalsContinuation {
networkManager: networkManager,
)
let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: networkManager)
let svr = SecureValueRecovery2Impl(
connectionFactory: SgxWebsocketConnectionFactoryImpl(websocketFactory: webSocketFactory),
credentialStorage: svrCredentialStorage,
db: db,
accountKeyStore: accountKeyStore,
pinHasher: LibSignalPinHasher(),
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
storageServiceManager: storageServiceManager,
svrLocalStorage: svrLocalStorage,
tsConstants: tsConstants,
@ -1859,6 +1861,7 @@ extension AppSetup.GlobalsContinuation {
recipientFetcher: recipientFetcher,
recipientManager: recipientManager,
recipientMerger: recipientMerger,
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
tsAccountManager: tsAccountManager,
udManager: udManager,
libsignalNet: libsignalNet,

View File

@ -328,7 +328,7 @@ public enum RegistrationServiceResponses {
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
timeRemainingMs = try container.decode(Int.self, forKey: .timeRemainingMs)
let svr2Credential = try container.decode(RemoteAttestation.Auth.self, forKey: .svr2AuthCredential)
let svr2Credential = try container.decode(RemoteAttestationAuth.self, forKey: .svr2AuthCredential)
self.svr2AuthCredential = .init(credential: svr2Credential)
}
}

View File

@ -1,97 +0,0 @@
//
// Copyright 2020 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import LibSignalClient
public enum RemoteAttestation {}
// MARK: - CSDI
extension RemoteAttestation {
static func authForCDSI() async throws -> Auth {
return try await Auth.fetch(forService: .cdsi, auth: .implicit())
}
}
// MARK: - SVR2
extension RemoteAttestation {
static func authForSVR2(chatServiceAuth: ChatServiceAuth) async throws -> Auth {
return try await Auth.fetch(forService: .svr2, auth: chatServiceAuth)
}
}
// MARK: - Errors
public extension RemoteAttestation {
enum Error: Swift.Error {
case assertion(reason: String)
}
}
private func attestationError(reason: String) -> RemoteAttestation.Error {
owsFailDebug("Error: \(reason)")
return .assertion(reason: reason)
}
// MARK: - Auth
public extension RemoteAttestation {
struct Auth: Equatable, Codable {
public let username: String
public let password: String
public init(authParamsDict: [String: Any]) throws {
guard let password = authParamsDict["password"] as? String, !password.isEmpty else {
throw attestationError(reason: "missing or empty password")
}
guard let username = authParamsDict["username"] as? String, !username.isEmpty else {
throw attestationError(reason: "missing or empty username")
}
self.init(username: username, password: password)
}
public init(username: String, password: String) {
self.username = username
self.password = password
}
}
}
private extension RemoteAttestation.Auth {
static func fetch(
forService service: RemoteAttestation.Service,
auth: ChatServiceAuth,
) async throws -> RemoteAttestation.Auth {
var request = service.authRequest()
request.auth = .identified(auth)
let response = try await SSKEnvironment.shared.networkManagerRef.asyncRequest(request)
guard let authParamsDict = response.responseBodyDict else {
throw attestationError(reason: "Missing or invalid JSON")
}
return try RemoteAttestation.Auth(authParamsDict: authParamsDict)
}
}
// MARK: - Service
private extension RemoteAttestation {
enum Service {
case cdsi
case svr2
func authRequest() -> TSRequest {
switch self {
case .cdsi: return OWSRequestFactory.remoteAttestationAuthRequestForCDSI()
case .svr2: return OWSRequestFactory.remoteAttestationAuthRequestForSVR2()
}
}
}
}

View File

@ -0,0 +1,69 @@
//
// Copyright 2020 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import LibSignalClient
// MARK: - Auth
public struct RemoteAttestationAuth: Equatable, Codable {
public let username: String
public let password: String
init(authParamsDict: [String: Any]) throws {
guard let password = authParamsDict["password"] as? String, !password.isEmpty else {
throw OWSAssertionError("missing or empty password")
}
guard let username = authParamsDict["username"] as? String, !username.isEmpty else {
throw OWSAssertionError("missing or empty username")
}
self.init(username: username, password: password)
}
public init(username: String, password: String) {
self.username = username
self.password = password
}
}
public struct RemoteAttestationAuthFetcher {
let networkManager: any NetworkManagerProtocol
public init(networkManager: any NetworkManagerProtocol) {
self.networkManager = networkManager
}
func fetchAuth(
forService service: Service,
chatServiceAuth: ChatServiceAuth,
) async throws -> RemoteAttestationAuth {
var request = service.authRequest()
request.auth = .identified(chatServiceAuth)
let response = try await networkManager.asyncRequest(request)
guard let authParamsDict = response.responseBodyDict else {
throw OWSAssertionError("Missing or invalid JSON")
}
return try RemoteAttestationAuth(authParamsDict: authParamsDict)
}
// MARK: - Service
public enum Service {
case cdsi
case svr2
func authRequest() -> TSRequest {
switch self {
case .cdsi: return OWSRequestFactory.remoteAttestationAuthRequestForCDSI()
case .svr2: return OWSRequestFactory.remoteAttestationAuthRequestForSVR2()
}
}
}
}

View File

@ -8,9 +8,9 @@ import Foundation
// Transparent wrapper that exists purely to make it clear to readers
// that the credential must be for SVR2, not an arbitrary RemoteAttestation.Auth.
public struct SVR2AuthCredential: Equatable, Codable {
public let credential: RemoteAttestation.Auth
public let credential: RemoteAttestationAuth
public init(credential: RemoteAttestation.Auth) {
public init(credential: RemoteAttestationAuth) {
self.credential = credential
}
}

View File

@ -14,10 +14,16 @@ class SVR2WebsocketConfigurator: SgxWebsocketConfigurator {
let mrenclave: MrEnclave
var authMethod: SVR2.AuthMethod
let remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher
init(mrenclave: MrEnclave, authMethod: SVR2.AuthMethod) {
init(
mrenclave: MrEnclave,
authMethod: SVR2.AuthMethod,
remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher,
) {
self.mrenclave = mrenclave
self.authMethod = authMethod
self.remoteAttestationAuthFetcher = remoteAttestationAuthFetcher
}
static var signalServiceType: SignalServiceType { .svr2 }
@ -26,14 +32,20 @@ class SVR2WebsocketConfigurator: SgxWebsocketConfigurator {
return "v1/\(mrenclaveString)"
}
func fetchAuth() async throws -> RemoteAttestation.Auth {
func fetchAuth() async throws -> RemoteAttestationAuth {
switch authMethod {
case .svrAuth(let credential, _):
return credential.credential
case .chatServerAuth(let authedAccount):
return try await RemoteAttestation.authForSVR2(chatServiceAuth: authedAccount.chatServiceAuth)
return try await remoteAttestationAuthFetcher.fetchAuth(
forService: .svr2,
chatServiceAuth: authedAccount.chatServiceAuth,
)
case .implicit:
return try await RemoteAttestation.authForSVR2(chatServiceAuth: .implicit())
return try await remoteAttestationAuthFetcher.fetchAuth(
forService: .svr2,
chatServiceAuth: .implicit(),
)
}
}

View File

@ -223,8 +223,8 @@ public class SVRAuthCredentialStorageImpl: SVRAuthCredentialStorage {
)
}
private var credential: RemoteAttestation.Auth {
return RemoteAttestation.Auth(username: username, password: password)
private var credential: RemoteAttestationAuth {
return RemoteAttestationAuth(username: username, password: password)
}
func toSVR2Credential() -> SVR2AuthCredential {

View File

@ -15,6 +15,7 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
private let db: any DB
private let accountKeyStore: AccountKeyStore
private let localStorage: SVRLocalStorage
private let remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher
private let storageServiceManager: StorageServiceManager
private let tsConstants: TSConstantsProtocol
private let twoFAManager: SVR2.Shims.OWS2FAManager
@ -25,6 +26,7 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
db: any DB,
accountKeyStore: AccountKeyStore,
pinHasher: any SVR2PinHasher,
remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher,
storageServiceManager: StorageServiceManager,
svrLocalStorage: SVRLocalStorage,
tsConstants: TSConstantsProtocol,
@ -36,6 +38,7 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
self.accountKeyStore = accountKeyStore
self.localStorage = svrLocalStorage
self.pinHasher = pinHasher
self.remoteAttestationAuthFetcher = remoteAttestationAuthFetcher
self.storageServiceManager = storageServiceManager
self.tsConstants = tsConstants
self.twoFAManager = twoFAManager
@ -52,7 +55,10 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
}
// Force refresh a credential, even if we have one cached, to ensure we
// have a fresh credential to back up.
let credential = try await RemoteAttestation.authForSVR2(chatServiceAuth: .implicit())
let credential = try await remoteAttestationAuthFetcher.fetchAuth(
forService: .svr2,
chatServiceAuth: .implicit(),
)
await db.awaitableWrite { tx in
credentialStorage.storeAuthCredentialForCurrentUsername(
SVR2AuthCredential(credential: credential),
@ -245,7 +251,11 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
return
}
let config = SVR2WebsocketConfigurator(mrenclave: mrEnclave, authMethod: authMethod)
let config = SVR2WebsocketConfigurator(
mrenclave: mrEnclave,
authMethod: authMethod,
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
)
let connection = try await makeHandshakeAndOpenConnection(config)
defer { connection.disconnect(code: .normalClosure) }
@ -286,7 +296,6 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
try await self.performExposeRequest(
backup: completedBackup,
authedAccount: authMethod.authedAccount,
connection: connection,
)
completedBackup.isExposed = true
@ -335,7 +344,6 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
private func performExposeRequest(
backup: CompletedBackup,
authedAccount: AuthedAccount,
connection: SgxWebsocketConnection<SVR2WebsocketConfigurator>,
) async throws {
var exposeRequest = SVR2Proto_ExposeRequest()
@ -408,7 +416,11 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
mrEnclave: MrEnclave,
authMethod: SVR2.AuthMethod,
) async throws -> RestoreResult {
let config = SVR2WebsocketConfigurator(mrenclave: mrEnclave, authMethod: authMethod)
let config = SVR2WebsocketConfigurator(
mrenclave: mrEnclave,
authMethod: authMethod,
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
)
do {
let connection = try await makeHandshakeAndOpenConnection(config)
defer { connection.disconnect(code: .normalClosure) }
@ -417,7 +429,6 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
mrEnclave: mrEnclave,
pin: pin,
connection: connection,
authedAccount: authMethod.authedAccount,
)
}
}
@ -426,7 +437,6 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
mrEnclave: MrEnclave,
pin: String,
connection: SgxWebsocketConnection<SVR2WebsocketConfigurator>,
authedAccount: AuthedAccount,
) async throws -> RestoreResult {
let pinHash = try hashPin(pin, forConnection: connection)
@ -462,7 +472,11 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
mrEnclave: MrEnclave,
authMethod: SVR2.AuthMethod,
) async throws {
let config = SVR2WebsocketConfigurator(mrenclave: mrEnclave, authMethod: authMethod)
let config = SVR2WebsocketConfigurator(
mrenclave: mrEnclave,
authMethod: authMethod,
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
)
let connection = try await makeHandshakeAndOpenConnection(config)
defer { connection.disconnect(code: .normalClosure) }
await db.awaitableWrite { tx in
@ -474,14 +488,12 @@ public class SecureValueRecovery2Impl: SecureValueRecovery {
return try await self.performDeleteRequest(
mrEnclave: mrEnclave,
connection: connection,
authedAccount: authMethod.authedAccount,
)
}
private func performDeleteRequest(
mrEnclave: MrEnclave,
connection: SgxWebsocketConnection<SVR2WebsocketConfigurator>,
authedAccount: AuthedAccount,
) async throws {
var request = SVR2Proto_Request()
request.delete = SVR2Proto_DeleteRequest()

View File

@ -31,7 +31,7 @@ public protocol SgxWebsocketConfigurator {
/// Called internally in order to fetch authentication to include in the header
/// when establishing the initial websocket connection.
func fetchAuth() async throws -> RemoteAttestation.Auth
func fetchAuth() async throws -> RemoteAttestationAuth
/// Called just after starting a websocket connection in order to use the
/// client for the handshake and subsequent messages.

View File

@ -27,7 +27,7 @@ public class SgxWebsocketConnection<Configurator: SgxWebsocketConfigurator> {
public var client: Configurator.Client { fatalError("Concrete subclass must implement") }
public var auth: RemoteAttestation.Auth { fatalError("Concrete subclass must implement") }
public var auth: RemoteAttestationAuth { fatalError("Concrete subclass must implement") }
// Subclasses must implement.
func sendRequestAndReadResponse(_ request: Configurator.Request) async throws -> Configurator.Response {
@ -45,13 +45,13 @@ public class SgxWebsocketConnectionImpl<Configurator: SgxWebsocketConfigurator>:
private let webSocket: WebSocketPromise
private let configurator: Configurator
private let _client: Configurator.Client
private let _auth: RemoteAttestation.Auth
private let _auth: RemoteAttestationAuth
private init(
webSocket: WebSocketPromise,
configurator: Configurator,
client: Configurator.Client,
auth: RemoteAttestation.Auth,
auth: RemoteAttestationAuth,
) {
self.webSocket = webSocket
self.configurator = configurator
@ -62,7 +62,7 @@ public class SgxWebsocketConnectionImpl<Configurator: SgxWebsocketConfigurator>:
static func connectAndPerformHandshake(
configurator: Configurator,
auth: RemoteAttestation.Auth,
auth: RemoteAttestationAuth,
websocketFactory: WebSocketFactory,
) async throws -> SgxWebsocketConnection<Configurator> {
let webSocket = try buildSocket(
@ -95,7 +95,7 @@ public class SgxWebsocketConnectionImpl<Configurator: SgxWebsocketConfigurator>:
private static func buildSocket(
configurator: Configurator,
auth: RemoteAttestation.Auth,
auth: RemoteAttestationAuth,
websocketFactory: WebSocketFactory,
) throws -> WebSocketPromise {
let authHeaderValue = HttpHeaders.authHeaderValue(username: auth.username, password: auth.password)
@ -115,7 +115,7 @@ public class SgxWebsocketConnectionImpl<Configurator: SgxWebsocketConfigurator>:
override public var client: Configurator.Client { return _client }
override public var auth: RemoteAttestation.Auth { return _auth }
override public var auth: RemoteAttestationAuth { return _auth }
override public func sendRequestAndReadResponse(
_ request: Configurator.Request,
@ -156,9 +156,9 @@ public class MockSgxWebsocketConnection<Configurator: SgxWebsocketConfigurator>:
override public var client: Configurator.Client { return mockClient }
public var mockAuth: RemoteAttestation.Auth!
public var mockAuth: RemoteAttestationAuth!
override public var auth: RemoteAttestation.Auth { return mockAuth }
override public var auth: RemoteAttestationAuth { return mockAuth }
public var onSendRequestAndReadResponse: ((Configurator.Request) -> Promise<Configurator.Response>)?

View File

@ -28,12 +28,6 @@ final class ContactDiscoveryV2OperationTest: XCTestCase {
}
}
private class MockRemoteAttestation: ContactDiscoveryV2Operation<MockContactDiscoveryConnection>.Shims.RemoteAttestation {
func authForCDSI() async throws -> RemoteAttestation.Auth {
return RemoteAttestation.Auth(username: "", password: "")
}
}
private class MockContactDiscoveryV2PersistentState: ContactDiscoveryV2PersistentState {
var token: Data?
var prevE164s = Set<E164>()
@ -61,12 +55,28 @@ final class ContactDiscoveryV2OperationTest: XCTestCase {
// MARK: - Tests
private lazy var persistentState = MockContactDiscoveryV2PersistentState()
private let mockNetworkManager = MockNetworkManager()
var authFetchSuccessResponse: (TSRequest, NetworkManager.RetryPolicy) async throws -> HTTPResponse = { request, _ in
if request.url.absoluteString.hasSuffix("v2/directory/auth") {
return HTTPResponse(
requestUrl: request.url,
status: 200,
headers: HttpHeaders(),
bodyData: try! JSONEncoder().encode(["password": "p", "username": "u"]),
)
}
throw OWSAssertionError("")
}
/// In .oneOffUserRequest mode, we should disregard tokens entirely.
func testOneOffRequest() async throws {
let aci = Aci.randomForTesting()
let pni = Pni.randomForTesting()
mockNetworkManager.asyncRequestHandlers.append(authFetchSuccessResponse)
let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: mockNetworkManager)
let connection = MockContactDiscoveryConnection()
let operation = ContactDiscoveryV2Operation(
db: InMemoryDB(),
@ -74,7 +84,7 @@ final class ContactDiscoveryV2OperationTest: XCTestCase {
mode: .oneOffUserRequest,
udManager: OWSMockUDManager(),
connectionImpl: connection,
remoteAttestation: MockRemoteAttestation(),
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
)
// Prepare the server's responses to the client's request.
@ -104,13 +114,16 @@ final class ContactDiscoveryV2OperationTest: XCTestCase {
func testNotDiscoverable() async throws {
let connection = MockContactDiscoveryConnection()
mockNetworkManager.asyncRequestHandlers.append(authFetchSuccessResponse)
let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: mockNetworkManager)
let operation = ContactDiscoveryV2Operation(
db: InMemoryDB(),
e164sToLookup: [try XCTUnwrap(E164("+16505550100"))],
persistentState: nil,
udManager: OWSMockUDManager(),
connectionImpl: connection,
remoteAttestation: MockRemoteAttestation(),
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
)
// Prepare the server's responses to the client's request.
@ -131,13 +144,15 @@ final class ContactDiscoveryV2OperationTest: XCTestCase {
/// If the server reports a rate limit, we should parse "retry after".
func testRateLimitError() async throws {
let connection = MockContactDiscoveryConnection()
mockNetworkManager.asyncRequestHandlers.append(authFetchSuccessResponse)
let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: mockNetworkManager)
let operation = ContactDiscoveryV2Operation(
db: InMemoryDB(),
e164sToLookup: [try XCTUnwrap(E164("+16505550100"))],
persistentState: persistentState,
udManager: OWSMockUDManager(),
connectionImpl: connection,
remoteAttestation: MockRemoteAttestation(),
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
)
// Establish the initial state.
@ -174,13 +189,15 @@ final class ContactDiscoveryV2OperationTest: XCTestCase {
/// If the server reports an invalid token, we should clear the token.
func testInvalidTokenError() async throws {
let connection = MockContactDiscoveryConnection()
mockNetworkManager.asyncRequestHandlers.append(authFetchSuccessResponse)
let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: mockNetworkManager)
let operation = ContactDiscoveryV2Operation(
db: InMemoryDB(),
e164sToLookup: [try XCTUnwrap(E164("+16505550100"))],
persistentState: persistentState,
udManager: OWSMockUDManager(),
connectionImpl: connection,
remoteAttestation: MockRemoteAttestation(),
remoteAttestationAuthFetcher: remoteAttestationAuthFetcher,
)
// Establish the initial state.

View File

@ -24,7 +24,7 @@ struct SVR2ConcurrencyTests {
mockConnection = MockSgxWebsocketConnection<SVR2WebsocketConfigurator>()
mockConnection.mockEnclave = TSConstants.shared.svr2Enclaves.first!
mockConnection.mockAuth = RemoteAttestation.Auth(username: "username", password: "password")
mockConnection.mockAuth = RemoteAttestationAuth(username: "username", password: "password")
mockConnectionFactory = MockSgxWebsocketConnectionFactory()
let accountKeyStore = AccountKeyStore(
@ -38,6 +38,7 @@ struct SVR2ConcurrencyTests {
db: db,
accountKeyStore: accountKeyStore,
pinHasher: MockPinHasher(),
remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher(networkManager: MockNetworkManager()),
storageServiceManager: FakeStorageServiceManager(),
svrLocalStorage: localStorage,
tsConstants: TSConstants.shared,
@ -87,7 +88,7 @@ struct SVR2ConcurrencyTests {
mockConnectionFactory.setOnConnectAndPerformHandshake { (_: SVR2WebsocketConfigurator) in
let mockConnection = MockSgxWebsocketConnection<SVR2WebsocketConfigurator>()
mockConnection.mockEnclave = TSConstants.shared.svr2Enclaves.first!
mockConnection.mockAuth = RemoteAttestation.Auth(username: "username", password: "password")
mockConnection.mockAuth = RemoteAttestationAuth(username: "username", password: "password")
mockConnection.onSendRequestAndReadResponse = onSendRequestAndReadResponse
return mockConnection
}
@ -121,11 +122,11 @@ struct SVR2ConcurrencyTests {
let firstMockConnection = MockSgxWebsocketConnection<SVR2WebsocketConfigurator>()
firstMockConnection.mockEnclave = TSConstants.shared.svr2Enclaves.first!
firstMockConnection.mockAuth = RemoteAttestation.Auth(username: "username", password: "password")
firstMockConnection.mockAuth = RemoteAttestationAuth(username: "username", password: "password")
let secondMockConnection = MockSgxWebsocketConnection<SVR2WebsocketConfigurator>()
secondMockConnection.mockEnclave = TSConstants.shared.svr2Enclaves.first!
secondMockConnection.mockAuth = RemoteAttestation.Auth(username: "username2", password: "password2")
secondMockConnection.mockAuth = RemoteAttestationAuth(username: "username2", password: "password2")
var numOpenedConnections = 0
mockConnectionFactory.setOnConnectAndPerformHandshake { (_: SVR2WebsocketConfigurator) in

View File

@ -33,7 +33,7 @@ class SecureValueRecovery2Tests: XCTestCase {
localStorage = SVRLocalStorage()
let mockConnection = MockSgxWebsocketConnection<SVR2WebsocketConfigurator>()
mockConnection.mockAuth = RemoteAttestation.Auth(username: "username", password: "password")
mockConnection.mockAuth = RemoteAttestationAuth(username: "username", password: "password")
self.mockConnection = mockConnection
mockConnectionFactory = MockSgxWebsocketConnectionFactory()
@ -45,6 +45,7 @@ class SecureValueRecovery2Tests: XCTestCase {
db: db,
accountKeyStore: accountKeyStore,
pinHasher: MockPinHasher(),
remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher(networkManager: MockNetworkManager()),
storageServiceManager: FakeStorageServiceManager(),
svrLocalStorage: localStorage,
tsConstants: mockTSConstants,
@ -55,7 +56,7 @@ class SecureValueRecovery2Tests: XCTestCase {
@MainActor
func testMigration() async throws {
// Set up the connections to both the old and new enclaves.
let mockAuth = RemoteAttestation.Auth(username: "username", password: "password")
let mockAuth = RemoteAttestationAuth(username: "username", password: "password")
let oldEnclave = MrEnclave("0000000000000000000000000000000000000000000000000000000000000000")
let oldEnclaveConnection = MockSgxWebsocketConnection<SVR2WebsocketConfigurator>()
@ -185,7 +186,7 @@ class SecureValueRecovery2Tests: XCTestCase {
@MainActor
func testMigration_forgottenEnclave() async throws {
// Set up the connections to both the old and new enclaves.
let mockAuth = RemoteAttestation.Auth(username: "username", password: "password")
let mockAuth = RemoteAttestationAuth(username: "username", password: "password")
let oldEnclave = MrEnclave("0000000000000000000000000000000000000000000000000000000000000000")
let oldEnclaveConnection = MockSgxWebsocketConnection<SVR2WebsocketConfigurator>()