diff --git a/Signal.xcodeproj/project.pbxproj b/Signal.xcodeproj/project.pbxproj index 62266d7760..7b7f2d8ab2 100644 --- a/Signal.xcodeproj/project.pbxproj +++ b/Signal.xcodeproj/project.pbxproj @@ -3822,7 +3822,7 @@ F9C5CCA3289453B300548EEE /* StorageService.pb.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9B8289453B100548EEE /* StorageService.pb.swift */; }; F9C5CCA4289453B300548EEE /* SSKProto+OWS.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9B9289453B100548EEE /* SSKProto+OWS.swift */; }; F9C5CCAC289453B300548EEE /* PreKeyManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9C2289453B100548EEE /* PreKeyManager.swift */; }; - F9C5CCB0289453B300548EEE /* RemoteAttestation.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9C7289453B100548EEE /* RemoteAttestation.swift */; }; + F9C5CCB0289453B300548EEE /* RemoteAttestationAuthFetcher.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9C7289453B100548EEE /* RemoteAttestationAuthFetcher.swift */; }; F9C5CCC0289453B300548EEE /* ContactDiscoveryTask.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9D9289453B100548EEE /* ContactDiscoveryTask.swift */; }; F9C5CCC3289453B300548EEE /* ContactDiscoveryError.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9DC289453B100548EEE /* ContactDiscoveryError.swift */; }; F9C5CCC5289453B300548EEE /* SignalAccount.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5C9DE289453B100548EEE /* SignalAccount.swift */; }; @@ -8136,7 +8136,7 @@ F9C5C9B8289453B100548EEE /* StorageService.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = StorageService.pb.swift; sourceTree = ""; }; F9C5C9B9289453B100548EEE /* SSKProto+OWS.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "SSKProto+OWS.swift"; sourceTree = ""; }; F9C5C9C2289453B100548EEE /* PreKeyManager.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PreKeyManager.swift; sourceTree = ""; }; - F9C5C9C7289453B100548EEE /* RemoteAttestation.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = RemoteAttestation.swift; sourceTree = ""; }; + F9C5C9C7289453B100548EEE /* RemoteAttestationAuthFetcher.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = RemoteAttestationAuthFetcher.swift; sourceTree = ""; }; F9C5C9D9289453B100548EEE /* ContactDiscoveryTask.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ContactDiscoveryTask.swift; sourceTree = ""; }; F9C5C9DC289453B100548EEE /* ContactDiscoveryError.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ContactDiscoveryError.swift; sourceTree = ""; }; F9C5C9DE289453B100548EEE /* SignalAccount.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SignalAccount.swift; sourceTree = ""; }; @@ -15367,7 +15367,7 @@ F9D5BFCC2979A017001737E5 /* OWSRequestFactory+Spam.swift */, D95C39E7296DEBFB00A9DA23 /* OWSRequestFactory+Usernames.swift */, F9C5CAE2289453B200548EEE /* OWSRequestFactory.swift */, - F9C5C9C7289453B100548EEE /* RemoteAttestation.swift */, + F9C5C9C7289453B100548EEE /* RemoteAttestationAuthFetcher.swift */, 66C2B1302A05D28A008DDE72 /* TSRequest.swift */, ); path = Requests; @@ -19723,7 +19723,7 @@ 6646573B2AC388C70099DE1C /* RegistrationStateChangeManager.swift in Sources */, 6646573D2AC3894D0099DE1C /* RegistrationStateChangeManagerImpl.swift in Sources */, 040506FC2F7FE3DB0078B769 /* RemoteAnnouncementModel.swift in Sources */, - F9C5CCB0289453B300548EEE /* RemoteAttestation.swift in Sources */, + F9C5CCB0289453B300548EEE /* RemoteAttestationAuthFetcher.swift in Sources */, F9C5CE17289453B400548EEE /* RemoteConfigManager.swift in Sources */, D98DD86028EE53B00089333E /* RemoteMegaphoneModel.swift in Sources */, 040507012F804C240078B769 /* RemoteReleaseNotesService.swift in Sources */, diff --git a/Signal/test/Registration/RegistrationCoordinatorTest.swift b/Signal/test/Registration/RegistrationCoordinatorTest.swift index a273cc6385..249097918e 100644 --- a/Signal/test/Registration/RegistrationCoordinatorTest.swift +++ b/Signal/test/Registration/RegistrationCoordinatorTest.swift @@ -3364,9 +3364,9 @@ public class RegistrationCoordinatorTest { // Put some auth credentials in storage. let svr2CredentialCandidates: [SVR2AuthCredential] = [ Stubs.svr2AuthCredential, - SVR2AuthCredential(credential: RemoteAttestation.Auth(username: "aaaa", password: "abc")), - SVR2AuthCredential(credential: RemoteAttestation.Auth(username: "zzzz", password: "xyz")), - SVR2AuthCredential(credential: RemoteAttestation.Auth(username: "0000", password: "123")), + SVR2AuthCredential(credential: RemoteAttestationAuth(username: "aaaa", password: "abc")), + SVR2AuthCredential(credential: RemoteAttestationAuth(username: "zzzz", password: "xyz")), + SVR2AuthCredential(credential: RemoteAttestationAuth(username: "0000", password: "123")), ] svrAuthCredentialStore.svr2Dict = Dictionary(grouping: svr2CredentialCandidates, by: \.credential.username).mapValues { $0.first! } @@ -3398,7 +3398,7 @@ public class RegistrationCoordinatorTest { static let aci = Aci.randomForTesting() static let pinCode = "1234" - static let svr2AuthCredential = SVR2AuthCredential(credential: RemoteAttestation.Auth(username: "xxx", password: "yyy")) + static let svr2AuthCredential = SVR2AuthCredential(credential: RemoteAttestationAuth(username: "xxx", password: "yyy")) static let captchaToken = "captchaToken" static let apnsToken = "apnsToken" diff --git a/SignalServiceKit/Contacts/Discovery/ContactDiscoveryManager.swift b/SignalServiceKit/Contacts/Discovery/ContactDiscoveryManager.swift index 00f7034b04..b7b2e9d7e6 100644 --- a/SignalServiceKit/Contacts/Discovery/ContactDiscoveryManager.swift +++ b/SignalServiceKit/Contacts/Discovery/ContactDiscoveryManager.swift @@ -104,6 +104,7 @@ public final class ContactDiscoveryManagerImpl: ContactDiscoveryManager { recipientFetcher: RecipientFetcher, recipientManager: any SignalRecipientManager, recipientMerger: RecipientMerger, + remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher, tsAccountManager: TSAccountManager, udManager: OWSUDManager, libsignalNet: Net, @@ -115,6 +116,7 @@ public final class ContactDiscoveryManagerImpl: ContactDiscoveryManager { recipientFetcher: recipientFetcher, recipientManager: recipientManager, recipientMerger: recipientMerger, + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, tsAccountManager: tsAccountManager, udManager: udManager, libsignalNet: libsignalNet, diff --git a/SignalServiceKit/Contacts/Discovery/ContactDiscoveryTask.swift b/SignalServiceKit/Contacts/Discovery/ContactDiscoveryTask.swift index 02ebdc30ca..fa52cd3cde 100644 --- a/SignalServiceKit/Contacts/Discovery/ContactDiscoveryTask.swift +++ b/SignalServiceKit/Contacts/Discovery/ContactDiscoveryTask.swift @@ -17,6 +17,7 @@ final class ContactDiscoveryTaskQueueImpl: ContactDiscoveryTaskQueue { private let recipientFetcher: RecipientFetcher private let recipientManager: any SignalRecipientManager private let recipientMerger: RecipientMerger + private let remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher private let tsAccountManager: TSAccountManager private let udManager: OWSUDManager private let libsignalNet: Net @@ -27,6 +28,7 @@ final class ContactDiscoveryTaskQueueImpl: ContactDiscoveryTaskQueue { recipientFetcher: RecipientFetcher, recipientManager: any SignalRecipientManager, recipientMerger: RecipientMerger, + remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher, tsAccountManager: TSAccountManager, udManager: OWSUDManager, libsignalNet: Net, @@ -36,6 +38,7 @@ final class ContactDiscoveryTaskQueueImpl: ContactDiscoveryTaskQueue { self.recipientFetcher = recipientFetcher self.recipientManager = recipientManager self.recipientMerger = recipientMerger + self.remoteAttestationAuthFetcher = remoteAttestationAuthFetcher self.tsAccountManager = tsAccountManager self.udManager = udManager self.libsignalNet = libsignalNet @@ -53,7 +56,7 @@ final class ContactDiscoveryTaskQueueImpl: ContactDiscoveryTaskQueue { mode: mode, udManager: udManager, connectionImpl: libsignalNet, - remoteAttestation: ContactDiscoveryV2Operation.Wrappers.RemoteAttestation(), + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, ).perform() return try await self.processResults(requestedPhoneNumbers: e164s, discoveryResults: discoveryResults) diff --git a/SignalServiceKit/Contacts/Discovery/ContactDiscoveryV2Operation.swift b/SignalServiceKit/Contacts/Discovery/ContactDiscoveryV2Operation.swift index 6fe79d6e5d..38591ad833 100644 --- a/SignalServiceKit/Contacts/Discovery/ContactDiscoveryV2Operation.swift +++ b/SignalServiceKit/Contacts/Discovery/ContactDiscoveryV2Operation.swift @@ -101,7 +101,7 @@ final class ContactDiscoveryV2Operation [ContactDiscoveryResult] { do { - let cdsiAuth = try await self.remoteAttestation.authForCDSI() + let cdsiAuth = try await self.remoteAttestationAuthFetcher.fetchAuth( + forService: .cdsi, + chatServiceAuth: .implicit(), + ) let request = try self.buildRequest() let auth = LibSignalClient.Auth(username: cdsiAuth.username, password: cdsiAuth.password) let tokenResult = try await self.connectionImpl.performRequest(request, auth: auth) @@ -322,27 +325,3 @@ private class ContactDiscoveryV2PersistentStateImpl: ContactDiscoveryV2Persisten } } } - -// MARK: - Shims - -extension ContactDiscoveryV2Operation { - enum Shims { - typealias RemoteAttestation = _ContactDiscoveryV2Operation_RemoteAttestationShim - } - - enum Wrappers { - typealias RemoteAttestation = _ContactDiscoveryV2Operation_RemoteAttestationWrapper - } -} - -protocol _ContactDiscoveryV2Operation_RemoteAttestationShim { - func authForCDSI() async throws -> RemoteAttestation.Auth -} - -class _ContactDiscoveryV2Operation_RemoteAttestationWrapper: _ContactDiscoveryV2Operation_RemoteAttestationShim { - init() {} - - func authForCDSI() async throws -> RemoteAttestation.Auth { - return try await RemoteAttestation.authForCDSI() - } -} diff --git a/SignalServiceKit/Environment/AppSetup.swift b/SignalServiceKit/Environment/AppSetup.swift index ef4ae375a8..934a150445 100644 --- a/SignalServiceKit/Environment/AppSetup.swift +++ b/SignalServiceKit/Environment/AppSetup.swift @@ -437,12 +437,14 @@ extension AppSetup.GlobalsContinuation { networkManager: networkManager, ) + let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: networkManager) let svr = SecureValueRecovery2Impl( connectionFactory: SgxWebsocketConnectionFactoryImpl(websocketFactory: webSocketFactory), credentialStorage: svrCredentialStorage, db: db, accountKeyStore: accountKeyStore, pinHasher: LibSignalPinHasher(), + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, storageServiceManager: storageServiceManager, svrLocalStorage: svrLocalStorage, tsConstants: tsConstants, @@ -1859,6 +1861,7 @@ extension AppSetup.GlobalsContinuation { recipientFetcher: recipientFetcher, recipientManager: recipientManager, recipientMerger: recipientMerger, + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, tsAccountManager: tsAccountManager, udManager: udManager, libsignalNet: libsignalNet, diff --git a/SignalServiceKit/Network/API/Requests/Registration/RegistrationServiceResponses.swift b/SignalServiceKit/Network/API/Requests/Registration/RegistrationServiceResponses.swift index 44ad6330ea..ee59edd824 100644 --- a/SignalServiceKit/Network/API/Requests/Registration/RegistrationServiceResponses.swift +++ b/SignalServiceKit/Network/API/Requests/Registration/RegistrationServiceResponses.swift @@ -328,7 +328,7 @@ public enum RegistrationServiceResponses { public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) timeRemainingMs = try container.decode(Int.self, forKey: .timeRemainingMs) - let svr2Credential = try container.decode(RemoteAttestation.Auth.self, forKey: .svr2AuthCredential) + let svr2Credential = try container.decode(RemoteAttestationAuth.self, forKey: .svr2AuthCredential) self.svr2AuthCredential = .init(credential: svr2Credential) } } diff --git a/SignalServiceKit/Network/API/Requests/RemoteAttestation.swift b/SignalServiceKit/Network/API/Requests/RemoteAttestation.swift deleted file mode 100644 index 5ea01456ee..0000000000 --- a/SignalServiceKit/Network/API/Requests/RemoteAttestation.swift +++ /dev/null @@ -1,97 +0,0 @@ -// -// Copyright 2020 Signal Messenger, LLC -// SPDX-License-Identifier: AGPL-3.0-only -// - -import Foundation -import LibSignalClient - -public enum RemoteAttestation {} - -// MARK: - CSDI - -extension RemoteAttestation { - static func authForCDSI() async throws -> Auth { - return try await Auth.fetch(forService: .cdsi, auth: .implicit()) - } -} - -// MARK: - SVR2 - -extension RemoteAttestation { - static func authForSVR2(chatServiceAuth: ChatServiceAuth) async throws -> Auth { - return try await Auth.fetch(forService: .svr2, auth: chatServiceAuth) - } -} - -// MARK: - Errors - -public extension RemoteAttestation { - enum Error: Swift.Error { - case assertion(reason: String) - } -} - -private func attestationError(reason: String) -> RemoteAttestation.Error { - owsFailDebug("Error: \(reason)") - return .assertion(reason: reason) -} - -// MARK: - Auth - -public extension RemoteAttestation { - struct Auth: Equatable, Codable { - public let username: String - public let password: String - - public init(authParamsDict: [String: Any]) throws { - guard let password = authParamsDict["password"] as? String, !password.isEmpty else { - throw attestationError(reason: "missing or empty password") - } - - guard let username = authParamsDict["username"] as? String, !username.isEmpty else { - throw attestationError(reason: "missing or empty username") - } - - self.init(username: username, password: password) - } - - public init(username: String, password: String) { - self.username = username - self.password = password - } - } -} - -private extension RemoteAttestation.Auth { - static func fetch( - forService service: RemoteAttestation.Service, - auth: ChatServiceAuth, - ) async throws -> RemoteAttestation.Auth { - var request = service.authRequest() - request.auth = .identified(auth) - let response = try await SSKEnvironment.shared.networkManagerRef.asyncRequest(request) - - guard let authParamsDict = response.responseBodyDict else { - throw attestationError(reason: "Missing or invalid JSON") - } - - return try RemoteAttestation.Auth(authParamsDict: authParamsDict) - } -} - -// MARK: - Service - -private extension RemoteAttestation { - enum Service { - case cdsi - case svr2 - - func authRequest() -> TSRequest { - switch self { - case .cdsi: return OWSRequestFactory.remoteAttestationAuthRequestForCDSI() - case .svr2: return OWSRequestFactory.remoteAttestationAuthRequestForSVR2() - } - } - } -} diff --git a/SignalServiceKit/Network/API/Requests/RemoteAttestationAuthFetcher.swift b/SignalServiceKit/Network/API/Requests/RemoteAttestationAuthFetcher.swift new file mode 100644 index 0000000000..79bfbefe4d --- /dev/null +++ b/SignalServiceKit/Network/API/Requests/RemoteAttestationAuthFetcher.swift @@ -0,0 +1,69 @@ +// +// Copyright 2020 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only +// + +import Foundation +import LibSignalClient + +// MARK: - Auth + +public struct RemoteAttestationAuth: Equatable, Codable { + public let username: String + public let password: String + + init(authParamsDict: [String: Any]) throws { + guard let password = authParamsDict["password"] as? String, !password.isEmpty else { + throw OWSAssertionError("missing or empty password") + } + + guard let username = authParamsDict["username"] as? String, !username.isEmpty else { + throw OWSAssertionError("missing or empty username") + } + + self.init(username: username, password: password) + } + + public init(username: String, password: String) { + self.username = username + self.password = password + } +} + +public struct RemoteAttestationAuthFetcher { + + let networkManager: any NetworkManagerProtocol + + public init(networkManager: any NetworkManagerProtocol) { + self.networkManager = networkManager + } + + func fetchAuth( + forService service: Service, + chatServiceAuth: ChatServiceAuth, + ) async throws -> RemoteAttestationAuth { + var request = service.authRequest() + request.auth = .identified(chatServiceAuth) + let response = try await networkManager.asyncRequest(request) + + guard let authParamsDict = response.responseBodyDict else { + throw OWSAssertionError("Missing or invalid JSON") + } + + return try RemoteAttestationAuth(authParamsDict: authParamsDict) + } + + // MARK: - Service + + public enum Service { + case cdsi + case svr2 + + func authRequest() -> TSRequest { + switch self { + case .cdsi: return OWSRequestFactory.remoteAttestationAuthRequestForCDSI() + case .svr2: return OWSRequestFactory.remoteAttestationAuthRequestForSVR2() + } + } + } +} diff --git a/SignalServiceKit/SecureValueRecovery/SVR2AuthCredential.swift b/SignalServiceKit/SecureValueRecovery/SVR2AuthCredential.swift index f4e68869e5..00a0466724 100644 --- a/SignalServiceKit/SecureValueRecovery/SVR2AuthCredential.swift +++ b/SignalServiceKit/SecureValueRecovery/SVR2AuthCredential.swift @@ -8,9 +8,9 @@ import Foundation // Transparent wrapper that exists purely to make it clear to readers // that the credential must be for SVR2, not an arbitrary RemoteAttestation.Auth. public struct SVR2AuthCredential: Equatable, Codable { - public let credential: RemoteAttestation.Auth + public let credential: RemoteAttestationAuth - public init(credential: RemoteAttestation.Auth) { + public init(credential: RemoteAttestationAuth) { self.credential = credential } } diff --git a/SignalServiceKit/SecureValueRecovery/SVR2WebsocketConfigurator.swift b/SignalServiceKit/SecureValueRecovery/SVR2WebsocketConfigurator.swift index 6583d4d467..80831ff1bd 100644 --- a/SignalServiceKit/SecureValueRecovery/SVR2WebsocketConfigurator.swift +++ b/SignalServiceKit/SecureValueRecovery/SVR2WebsocketConfigurator.swift @@ -14,10 +14,16 @@ class SVR2WebsocketConfigurator: SgxWebsocketConfigurator { let mrenclave: MrEnclave var authMethod: SVR2.AuthMethod + let remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher - init(mrenclave: MrEnclave, authMethod: SVR2.AuthMethod) { + init( + mrenclave: MrEnclave, + authMethod: SVR2.AuthMethod, + remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher, + ) { self.mrenclave = mrenclave self.authMethod = authMethod + self.remoteAttestationAuthFetcher = remoteAttestationAuthFetcher } static var signalServiceType: SignalServiceType { .svr2 } @@ -26,14 +32,20 @@ class SVR2WebsocketConfigurator: SgxWebsocketConfigurator { return "v1/\(mrenclaveString)" } - func fetchAuth() async throws -> RemoteAttestation.Auth { + func fetchAuth() async throws -> RemoteAttestationAuth { switch authMethod { case .svrAuth(let credential, _): return credential.credential case .chatServerAuth(let authedAccount): - return try await RemoteAttestation.authForSVR2(chatServiceAuth: authedAccount.chatServiceAuth) + return try await remoteAttestationAuthFetcher.fetchAuth( + forService: .svr2, + chatServiceAuth: authedAccount.chatServiceAuth, + ) case .implicit: - return try await RemoteAttestation.authForSVR2(chatServiceAuth: .implicit()) + return try await remoteAttestationAuthFetcher.fetchAuth( + forService: .svr2, + chatServiceAuth: .implicit(), + ) } } diff --git a/SignalServiceKit/SecureValueRecovery/SVRAuthCredentialStorageImpl.swift b/SignalServiceKit/SecureValueRecovery/SVRAuthCredentialStorageImpl.swift index a3d4fb66c5..2867de884a 100644 --- a/SignalServiceKit/SecureValueRecovery/SVRAuthCredentialStorageImpl.swift +++ b/SignalServiceKit/SecureValueRecovery/SVRAuthCredentialStorageImpl.swift @@ -223,8 +223,8 @@ public class SVRAuthCredentialStorageImpl: SVRAuthCredentialStorage { ) } - private var credential: RemoteAttestation.Auth { - return RemoteAttestation.Auth(username: username, password: password) + private var credential: RemoteAttestationAuth { + return RemoteAttestationAuth(username: username, password: password) } func toSVR2Credential() -> SVR2AuthCredential { diff --git a/SignalServiceKit/SecureValueRecovery/SecureValueRecovery2Impl.swift b/SignalServiceKit/SecureValueRecovery/SecureValueRecovery2Impl.swift index 996788a559..79d51950f0 100644 --- a/SignalServiceKit/SecureValueRecovery/SecureValueRecovery2Impl.swift +++ b/SignalServiceKit/SecureValueRecovery/SecureValueRecovery2Impl.swift @@ -15,6 +15,7 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { private let db: any DB private let accountKeyStore: AccountKeyStore private let localStorage: SVRLocalStorage + private let remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher private let storageServiceManager: StorageServiceManager private let tsConstants: TSConstantsProtocol private let twoFAManager: SVR2.Shims.OWS2FAManager @@ -25,6 +26,7 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { db: any DB, accountKeyStore: AccountKeyStore, pinHasher: any SVR2PinHasher, + remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher, storageServiceManager: StorageServiceManager, svrLocalStorage: SVRLocalStorage, tsConstants: TSConstantsProtocol, @@ -36,6 +38,7 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { self.accountKeyStore = accountKeyStore self.localStorage = svrLocalStorage self.pinHasher = pinHasher + self.remoteAttestationAuthFetcher = remoteAttestationAuthFetcher self.storageServiceManager = storageServiceManager self.tsConstants = tsConstants self.twoFAManager = twoFAManager @@ -52,7 +55,10 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { } // Force refresh a credential, even if we have one cached, to ensure we // have a fresh credential to back up. - let credential = try await RemoteAttestation.authForSVR2(chatServiceAuth: .implicit()) + let credential = try await remoteAttestationAuthFetcher.fetchAuth( + forService: .svr2, + chatServiceAuth: .implicit(), + ) await db.awaitableWrite { tx in credentialStorage.storeAuthCredentialForCurrentUsername( SVR2AuthCredential(credential: credential), @@ -245,7 +251,11 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { return } - let config = SVR2WebsocketConfigurator(mrenclave: mrEnclave, authMethod: authMethod) + let config = SVR2WebsocketConfigurator( + mrenclave: mrEnclave, + authMethod: authMethod, + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, + ) let connection = try await makeHandshakeAndOpenConnection(config) defer { connection.disconnect(code: .normalClosure) } @@ -286,7 +296,6 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { try await self.performExposeRequest( backup: completedBackup, - authedAccount: authMethod.authedAccount, connection: connection, ) completedBackup.isExposed = true @@ -335,7 +344,6 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { private func performExposeRequest( backup: CompletedBackup, - authedAccount: AuthedAccount, connection: SgxWebsocketConnection, ) async throws { var exposeRequest = SVR2Proto_ExposeRequest() @@ -408,7 +416,11 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { mrEnclave: MrEnclave, authMethod: SVR2.AuthMethod, ) async throws -> RestoreResult { - let config = SVR2WebsocketConfigurator(mrenclave: mrEnclave, authMethod: authMethod) + let config = SVR2WebsocketConfigurator( + mrenclave: mrEnclave, + authMethod: authMethod, + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, + ) do { let connection = try await makeHandshakeAndOpenConnection(config) defer { connection.disconnect(code: .normalClosure) } @@ -417,7 +429,6 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { mrEnclave: mrEnclave, pin: pin, connection: connection, - authedAccount: authMethod.authedAccount, ) } } @@ -426,7 +437,6 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { mrEnclave: MrEnclave, pin: String, connection: SgxWebsocketConnection, - authedAccount: AuthedAccount, ) async throws -> RestoreResult { let pinHash = try hashPin(pin, forConnection: connection) @@ -462,7 +472,11 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { mrEnclave: MrEnclave, authMethod: SVR2.AuthMethod, ) async throws { - let config = SVR2WebsocketConfigurator(mrenclave: mrEnclave, authMethod: authMethod) + let config = SVR2WebsocketConfigurator( + mrenclave: mrEnclave, + authMethod: authMethod, + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, + ) let connection = try await makeHandshakeAndOpenConnection(config) defer { connection.disconnect(code: .normalClosure) } await db.awaitableWrite { tx in @@ -474,14 +488,12 @@ public class SecureValueRecovery2Impl: SecureValueRecovery { return try await self.performDeleteRequest( mrEnclave: mrEnclave, connection: connection, - authedAccount: authMethod.authedAccount, ) } private func performDeleteRequest( mrEnclave: MrEnclave, connection: SgxWebsocketConnection, - authedAccount: AuthedAccount, ) async throws { var request = SVR2Proto_Request() request.delete = SVR2Proto_DeleteRequest() diff --git a/SignalServiceKit/SecureValueRecovery/SgxWebsocketConfigurator.swift b/SignalServiceKit/SecureValueRecovery/SgxWebsocketConfigurator.swift index f2066163c8..a08246c7bd 100644 --- a/SignalServiceKit/SecureValueRecovery/SgxWebsocketConfigurator.swift +++ b/SignalServiceKit/SecureValueRecovery/SgxWebsocketConfigurator.swift @@ -31,7 +31,7 @@ public protocol SgxWebsocketConfigurator { /// Called internally in order to fetch authentication to include in the header /// when establishing the initial websocket connection. - func fetchAuth() async throws -> RemoteAttestation.Auth + func fetchAuth() async throws -> RemoteAttestationAuth /// Called just after starting a websocket connection in order to use the /// client for the handshake and subsequent messages. diff --git a/SignalServiceKit/SecureValueRecovery/SgxWebsocketConnection.swift b/SignalServiceKit/SecureValueRecovery/SgxWebsocketConnection.swift index 5731b061b8..d3b5e7db34 100644 --- a/SignalServiceKit/SecureValueRecovery/SgxWebsocketConnection.swift +++ b/SignalServiceKit/SecureValueRecovery/SgxWebsocketConnection.swift @@ -27,7 +27,7 @@ public class SgxWebsocketConnection { public var client: Configurator.Client { fatalError("Concrete subclass must implement") } - public var auth: RemoteAttestation.Auth { fatalError("Concrete subclass must implement") } + public var auth: RemoteAttestationAuth { fatalError("Concrete subclass must implement") } // Subclasses must implement. func sendRequestAndReadResponse(_ request: Configurator.Request) async throws -> Configurator.Response { @@ -45,13 +45,13 @@ public class SgxWebsocketConnectionImpl: private let webSocket: WebSocketPromise private let configurator: Configurator private let _client: Configurator.Client - private let _auth: RemoteAttestation.Auth + private let _auth: RemoteAttestationAuth private init( webSocket: WebSocketPromise, configurator: Configurator, client: Configurator.Client, - auth: RemoteAttestation.Auth, + auth: RemoteAttestationAuth, ) { self.webSocket = webSocket self.configurator = configurator @@ -62,7 +62,7 @@ public class SgxWebsocketConnectionImpl: static func connectAndPerformHandshake( configurator: Configurator, - auth: RemoteAttestation.Auth, + auth: RemoteAttestationAuth, websocketFactory: WebSocketFactory, ) async throws -> SgxWebsocketConnection { let webSocket = try buildSocket( @@ -95,7 +95,7 @@ public class SgxWebsocketConnectionImpl: private static func buildSocket( configurator: Configurator, - auth: RemoteAttestation.Auth, + auth: RemoteAttestationAuth, websocketFactory: WebSocketFactory, ) throws -> WebSocketPromise { let authHeaderValue = HttpHeaders.authHeaderValue(username: auth.username, password: auth.password) @@ -115,7 +115,7 @@ public class SgxWebsocketConnectionImpl: override public var client: Configurator.Client { return _client } - override public var auth: RemoteAttestation.Auth { return _auth } + override public var auth: RemoteAttestationAuth { return _auth } override public func sendRequestAndReadResponse( _ request: Configurator.Request, @@ -156,9 +156,9 @@ public class MockSgxWebsocketConnection: override public var client: Configurator.Client { return mockClient } - public var mockAuth: RemoteAttestation.Auth! + public var mockAuth: RemoteAttestationAuth! - override public var auth: RemoteAttestation.Auth { return mockAuth } + override public var auth: RemoteAttestationAuth { return mockAuth } public var onSendRequestAndReadResponse: ((Configurator.Request) -> Promise)? diff --git a/SignalServiceKit/tests/Contacts/Discovery/ContactDiscoveryV2OperationTest.swift b/SignalServiceKit/tests/Contacts/Discovery/ContactDiscoveryV2OperationTest.swift index 37008666fc..3e3eb41ddb 100644 --- a/SignalServiceKit/tests/Contacts/Discovery/ContactDiscoveryV2OperationTest.swift +++ b/SignalServiceKit/tests/Contacts/Discovery/ContactDiscoveryV2OperationTest.swift @@ -28,12 +28,6 @@ final class ContactDiscoveryV2OperationTest: XCTestCase { } } - private class MockRemoteAttestation: ContactDiscoveryV2Operation.Shims.RemoteAttestation { - func authForCDSI() async throws -> RemoteAttestation.Auth { - return RemoteAttestation.Auth(username: "", password: "") - } - } - private class MockContactDiscoveryV2PersistentState: ContactDiscoveryV2PersistentState { var token: Data? var prevE164s = Set() @@ -61,12 +55,28 @@ final class ContactDiscoveryV2OperationTest: XCTestCase { // MARK: - Tests private lazy var persistentState = MockContactDiscoveryV2PersistentState() + private let mockNetworkManager = MockNetworkManager() + + var authFetchSuccessResponse: (TSRequest, NetworkManager.RetryPolicy) async throws -> HTTPResponse = { request, _ in + if request.url.absoluteString.hasSuffix("v2/directory/auth") { + return HTTPResponse( + requestUrl: request.url, + status: 200, + headers: HttpHeaders(), + bodyData: try! JSONEncoder().encode(["password": "p", "username": "u"]), + ) + } + throw OWSAssertionError("") + } /// In .oneOffUserRequest mode, we should disregard tokens entirely. func testOneOffRequest() async throws { let aci = Aci.randomForTesting() let pni = Pni.randomForTesting() + mockNetworkManager.asyncRequestHandlers.append(authFetchSuccessResponse) + let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: mockNetworkManager) + let connection = MockContactDiscoveryConnection() let operation = ContactDiscoveryV2Operation( db: InMemoryDB(), @@ -74,7 +84,7 @@ final class ContactDiscoveryV2OperationTest: XCTestCase { mode: .oneOffUserRequest, udManager: OWSMockUDManager(), connectionImpl: connection, - remoteAttestation: MockRemoteAttestation(), + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, ) // Prepare the server's responses to the client's request. @@ -104,13 +114,16 @@ final class ContactDiscoveryV2OperationTest: XCTestCase { func testNotDiscoverable() async throws { let connection = MockContactDiscoveryConnection() + mockNetworkManager.asyncRequestHandlers.append(authFetchSuccessResponse) + let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: mockNetworkManager) + let operation = ContactDiscoveryV2Operation( db: InMemoryDB(), e164sToLookup: [try XCTUnwrap(E164("+16505550100"))], persistentState: nil, udManager: OWSMockUDManager(), connectionImpl: connection, - remoteAttestation: MockRemoteAttestation(), + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, ) // Prepare the server's responses to the client's request. @@ -131,13 +144,15 @@ final class ContactDiscoveryV2OperationTest: XCTestCase { /// If the server reports a rate limit, we should parse "retry after". func testRateLimitError() async throws { let connection = MockContactDiscoveryConnection() + mockNetworkManager.asyncRequestHandlers.append(authFetchSuccessResponse) + let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: mockNetworkManager) let operation = ContactDiscoveryV2Operation( db: InMemoryDB(), e164sToLookup: [try XCTUnwrap(E164("+16505550100"))], persistentState: persistentState, udManager: OWSMockUDManager(), connectionImpl: connection, - remoteAttestation: MockRemoteAttestation(), + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, ) // Establish the initial state. @@ -174,13 +189,15 @@ final class ContactDiscoveryV2OperationTest: XCTestCase { /// If the server reports an invalid token, we should clear the token. func testInvalidTokenError() async throws { let connection = MockContactDiscoveryConnection() + mockNetworkManager.asyncRequestHandlers.append(authFetchSuccessResponse) + let remoteAttestationAuthFetcher = RemoteAttestationAuthFetcher(networkManager: mockNetworkManager) let operation = ContactDiscoveryV2Operation( db: InMemoryDB(), e164sToLookup: [try XCTUnwrap(E164("+16505550100"))], persistentState: persistentState, udManager: OWSMockUDManager(), connectionImpl: connection, - remoteAttestation: MockRemoteAttestation(), + remoteAttestationAuthFetcher: remoteAttestationAuthFetcher, ) // Establish the initial state. diff --git a/SignalServiceKit/tests/SecureValueRecovery/SVR2/SVR2ConcurrencyTests.swift b/SignalServiceKit/tests/SecureValueRecovery/SVR2/SVR2ConcurrencyTests.swift index ba75b4e55e..a09e354627 100644 --- a/SignalServiceKit/tests/SecureValueRecovery/SVR2/SVR2ConcurrencyTests.swift +++ b/SignalServiceKit/tests/SecureValueRecovery/SVR2/SVR2ConcurrencyTests.swift @@ -24,7 +24,7 @@ struct SVR2ConcurrencyTests { mockConnection = MockSgxWebsocketConnection() mockConnection.mockEnclave = TSConstants.shared.svr2Enclaves.first! - mockConnection.mockAuth = RemoteAttestation.Auth(username: "username", password: "password") + mockConnection.mockAuth = RemoteAttestationAuth(username: "username", password: "password") mockConnectionFactory = MockSgxWebsocketConnectionFactory() let accountKeyStore = AccountKeyStore( @@ -38,6 +38,7 @@ struct SVR2ConcurrencyTests { db: db, accountKeyStore: accountKeyStore, pinHasher: MockPinHasher(), + remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher(networkManager: MockNetworkManager()), storageServiceManager: FakeStorageServiceManager(), svrLocalStorage: localStorage, tsConstants: TSConstants.shared, @@ -87,7 +88,7 @@ struct SVR2ConcurrencyTests { mockConnectionFactory.setOnConnectAndPerformHandshake { (_: SVR2WebsocketConfigurator) in let mockConnection = MockSgxWebsocketConnection() mockConnection.mockEnclave = TSConstants.shared.svr2Enclaves.first! - mockConnection.mockAuth = RemoteAttestation.Auth(username: "username", password: "password") + mockConnection.mockAuth = RemoteAttestationAuth(username: "username", password: "password") mockConnection.onSendRequestAndReadResponse = onSendRequestAndReadResponse return mockConnection } @@ -121,11 +122,11 @@ struct SVR2ConcurrencyTests { let firstMockConnection = MockSgxWebsocketConnection() firstMockConnection.mockEnclave = TSConstants.shared.svr2Enclaves.first! - firstMockConnection.mockAuth = RemoteAttestation.Auth(username: "username", password: "password") + firstMockConnection.mockAuth = RemoteAttestationAuth(username: "username", password: "password") let secondMockConnection = MockSgxWebsocketConnection() secondMockConnection.mockEnclave = TSConstants.shared.svr2Enclaves.first! - secondMockConnection.mockAuth = RemoteAttestation.Auth(username: "username2", password: "password2") + secondMockConnection.mockAuth = RemoteAttestationAuth(username: "username2", password: "password2") var numOpenedConnections = 0 mockConnectionFactory.setOnConnectAndPerformHandshake { (_: SVR2WebsocketConfigurator) in diff --git a/SignalServiceKit/tests/SecureValueRecovery/SVR2/SecureValueRecovery2Tests.swift b/SignalServiceKit/tests/SecureValueRecovery/SVR2/SecureValueRecovery2Tests.swift index e3a47b04e6..fdab03aacd 100644 --- a/SignalServiceKit/tests/SecureValueRecovery/SVR2/SecureValueRecovery2Tests.swift +++ b/SignalServiceKit/tests/SecureValueRecovery/SVR2/SecureValueRecovery2Tests.swift @@ -33,7 +33,7 @@ class SecureValueRecovery2Tests: XCTestCase { localStorage = SVRLocalStorage() let mockConnection = MockSgxWebsocketConnection() - mockConnection.mockAuth = RemoteAttestation.Auth(username: "username", password: "password") + mockConnection.mockAuth = RemoteAttestationAuth(username: "username", password: "password") self.mockConnection = mockConnection mockConnectionFactory = MockSgxWebsocketConnectionFactory() @@ -45,6 +45,7 @@ class SecureValueRecovery2Tests: XCTestCase { db: db, accountKeyStore: accountKeyStore, pinHasher: MockPinHasher(), + remoteAttestationAuthFetcher: RemoteAttestationAuthFetcher(networkManager: MockNetworkManager()), storageServiceManager: FakeStorageServiceManager(), svrLocalStorage: localStorage, tsConstants: mockTSConstants, @@ -55,7 +56,7 @@ class SecureValueRecovery2Tests: XCTestCase { @MainActor func testMigration() async throws { // Set up the connections to both the old and new enclaves. - let mockAuth = RemoteAttestation.Auth(username: "username", password: "password") + let mockAuth = RemoteAttestationAuth(username: "username", password: "password") let oldEnclave = MrEnclave("0000000000000000000000000000000000000000000000000000000000000000") let oldEnclaveConnection = MockSgxWebsocketConnection() @@ -185,7 +186,7 @@ class SecureValueRecovery2Tests: XCTestCase { @MainActor func testMigration_forgottenEnclave() async throws { // Set up the connections to both the old and new enclaves. - let mockAuth = RemoteAttestation.Auth(username: "username", password: "password") + let mockAuth = RemoteAttestationAuth(username: "username", password: "password") let oldEnclave = MrEnclave("0000000000000000000000000000000000000000000000000000000000000000") let oldEnclaveConnection = MockSgxWebsocketConnection()