Signal-iOS/SignalServiceKit/tests/Account/PreKeys/PreKeyTaskTests.swift
2025-12-30 11:34:05 -08:00

451 lines
16 KiB
Swift

//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import XCTest
@testable import SignalServiceKit
/// Tests for `PreKeyTaskManager.refresh(identity:targets:force:auth:)`.
///
/// Each test drives a `PreKeyTaskManager` wired to mocks (API client, identity
/// manager, identity-key-mismatch manager, date provider) and an `InMemoryDB`,
/// then inspects what was handed to the mock API client for upload and what was
/// persisted in the ACI pre-key store.
final class PreKeyTaskTests: SSKBaseTest {
    private var mockTSAccountManager: MockTSAccountManager!
    private var mockIdentityManager: MockIdentityManager!
    private var mockIdentityKeyMismatchManager: PreKeyTaskManager.Mocks.IdentityKeyMismatchManager!
    private var mockAPIClient: PreKeyTaskManager.Mocks.APIClient!
    private var mockDateProvider: PreKeyTaskManager.Mocks.DateProvider!
    private var mockDb: InMemoryDB!
    private var taskManager: PreKeyTaskManager!
    private var mockAciProtocolStore: SignalProtocolStore!
    private var mockPniProtocolStore: SignalProtocolStore!
    private var mockProtocolStoreManager: SignalProtocolStoreManager!
    private var mockPreKeyStore: SignalServiceKit.PreKeyStore!

    /// Builds a fresh `PreKeyTaskManager` and its collaborators for every test.
    ///
    /// Incoming message processing is disabled on the test app context. The ACI
    /// and PNI protocol stores share one `PreKeyStore` and one `SessionStore`.
    /// They and the task manager all receive `mockDateProvider.targetDate` as
    /// their date provider, so tests adjust "now" through `mockDateProvider.currentDate`.
    override func setUp() {
        super.setUp()
        let testContext = (CurrentAppContext() as! TestAppContext)
        testContext.shouldProcessIncomingMessages = false
        let recipientDbTable = RecipientDatabaseTable()
        let recipientFetcher = RecipientFetcher(
            recipientDatabaseTable: recipientDbTable,
            searchableNameIndexer: MockSearchableNameIndexer(),
        )
        let recipientIdFinder = RecipientIdFinder(
            recipientDatabaseTable: recipientDbTable,
            recipientFetcher: recipientFetcher,
        )
        mockIdentityManager = .init(recipientIdFinder: recipientIdFinder)
        mockTSAccountManager = .init()
        mockIdentityKeyMismatchManager = .init()
        mockAPIClient = .init()
        mockDateProvider = .init()
        mockDb = InMemoryDB()
        let sessionStore = SignalServiceKit.SessionStore()
        mockPreKeyStore = PreKeyStore()
        mockAciProtocolStore = .build(
            dateProvider: mockDateProvider.targetDate,
            identity: .aci,
            preKeyStore: mockPreKeyStore,
            recipientIdFinder: recipientIdFinder,
            sessionStore: sessionStore,
        )
        mockPniProtocolStore = .build(
            dateProvider: mockDateProvider.targetDate,
            identity: .pni,
            preKeyStore: mockPreKeyStore,
            recipientIdFinder: recipientIdFinder,
            sessionStore: sessionStore,
        )
        mockProtocolStoreManager = SignalProtocolStoreManager(
            aciProtocolStore: mockAciProtocolStore,
            pniProtocolStore: mockPniProtocolStore,
            preKeyStore: mockPreKeyStore,
            sessionStore: sessionStore,
        )
        taskManager = PreKeyTaskManager(
            apiClient: mockAPIClient,
            dateProvider: mockDateProvider.targetDate,
            db: mockDb,
            identityKeyMismatchManager: mockIdentityKeyMismatchManager,
            identityManager: mockIdentityManager,
            messageProcessor: SSKEnvironment.shared.messageProcessorRef,
            protocolStoreManager: mockProtocolStoreManager,
            remoteConfigProvider: MockRemoteConfigProvider(),
            tsAccountManager: mockTSAccountManager,
        )
    }

    /// Calls `ensureUnset()` on the mock API client's `setPreKeysResult`.
    // NOTE(review): presumably this asserts that any upload result a test
    // configured was actually consumed by an upload — confirm against the mock.
    override func tearDown() {
        mockAPIClient.setPreKeysResult.ensureUnset()
        super.tearDown()
    }

    /// Number of one-time EC pre-keys stored for the ACI identity.
    private func aciPreKeyCount() -> Int {
        return mockDb.read { tx in
            return try! mockPreKeyStore.aciStore.fetchCount(in: .oneTime, isOneTime: true, tx: tx)
        }
    }

    /// Number of signed EC pre-keys stored for the ACI identity.
    private func aciSignedPreKeyCount() -> Int {
        return mockDb.read { tx in
            return try! mockPreKeyStore.aciStore.fetchCount(in: .signed, isOneTime: false, tx: tx)
        }
    }

    /// Number of one-time Kyber (PQ) pre-keys stored for the ACI identity.
    private func aciKyberOneTimePreKeyCount() -> Int {
        return mockDb.read { tx in
            return try! mockPreKeyStore.aciStore.fetchCount(in: .kyber, isOneTime: true, tx: tx)
        }
    }

    /// Number of last-resort (not one-time) Kyber pre-keys stored for the ACI identity.
    private func aciKyberLastResortPreKeyCount() -> Int {
        return mockDb.read { tx in
            return try! mockPreKeyStore.aciStore.fetchCount(in: .kyber, isOneTime: false, tx: tx)
        }
    }

    //
    //
    // MARK: - Create PreKey Tests
    //
    //

    /// A forced refresh of `.all` uploads 100 one-time EC pre-keys, a signed
    /// pre-key, a PQ last-resort pre-key, and 100 one-time PQ pre-keys.
    func testCreateAll() async throws {
        mockAPIClient.setPreKeysResult = .value(())
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        _ = try await taskManager.refresh(identity: .aci, targets: .all, force: true, auth: .implicit())
        // Validate
        XCTAssertEqual(mockAPIClient.preKeyRecords?.count, 100)
        XCTAssertNotNil(mockAPIClient.signedPreKeyRecord)
        XCTAssertNotNil(mockAPIClient.pqLastResortPreKeyRecord)
        XCTAssertNotNil(mockAPIClient.pqPreKeyRecords)
        XCTAssertEqual(mockAPIClient.pqPreKeyRecords?.count, 100)
    }

    /// Targeting only `.signedPreKey` stores and uploads exactly one signed
    /// pre-key and nothing else.
    func testCreateSignedPreKeyOnly() async throws {
        mockAPIClient.setPreKeysResult = .value(())
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        // Pre-validate
        XCTAssertEqual(aciPreKeyCount(), 0)
        XCTAssertEqual(aciSignedPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .signedPreKey, force: true, auth: .implicit())
        // Validate
        XCTAssertEqual(aciPreKeyCount(), 0)
        XCTAssertEqual(aciSignedPreKeyCount(), 1)
        XCTAssertNotNil(mockAPIClient.signedPreKeyRecord)
        XCTAssertNil(mockAPIClient.preKeyRecords)
        XCTAssertNil(mockAPIClient.pqPreKeyRecords)
        XCTAssertNil(mockAPIClient.pqLastResortPreKeyRecord)
    }

    /// Targeting only `.oneTimePreKey` stores and uploads 100 one-time EC
    /// pre-keys and nothing else.
    func testCreatePreKeyOnly() async throws {
        mockAPIClient.setPreKeysResult = .value(())
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        // Pre-validate
        XCTAssertEqual(aciPreKeyCount(), 0)
        XCTAssertEqual(aciSignedPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .oneTimePreKey, force: true, auth: .implicit())
        // Validate
        XCTAssertEqual(aciPreKeyCount(), 100)
        XCTAssertEqual(aciSignedPreKeyCount(), 0)
        XCTAssertEqual(mockAPIClient.preKeyRecords?.count, 100)
        XCTAssertNil(mockAPIClient.signedPreKeyRecord)
        XCTAssertNil(mockAPIClient.pqPreKeyRecords)
        XCTAssertNil(mockAPIClient.pqLastResortPreKeyRecord)
        XCTAssertNotNil(mockAPIClient.preKeyRecords)
    }

    // Test that the IdentityManager keypair makes it through to the
    // service client
    func testMockPreKeyTaskCreateWithExistingIdentityKey() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        XCTAssertEqual(aciPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .all, force: true, auth: .implicit())
        XCTAssertEqual(mockAPIClient.preKeyRecords?.count, 100)
        XCTAssertNotNil(mockAPIClient.signedPreKeyRecord)
        XCTAssertNotNil(mockAPIClient.pqPreKeyRecords)
        XCTAssertEqual(mockAPIClient.pqPreKeyRecords?.count, 100)
        XCTAssertNotNil(mockAPIClient.pqLastResortPreKeyRecord)
    }

    /// A forced `.oneTimePreKey` refresh uploads 100 new one-time EC pre-keys
    /// even though the server already reports 100.
    func testMockCreatePreKeyOnlyWithExisting() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        mockAPIClient.currentPreKeyCount = 100
        XCTAssertEqual(aciPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .oneTimePreKey, force: true, auth: .implicit())
        XCTAssertEqual(mockAPIClient.preKeyRecords?.count, 100)
        XCTAssertEqual(aciSignedPreKeyCount(), 0)
        XCTAssertNil(mockAPIClient.signedPreKeyRecord)
        XCTAssertNil(mockAPIClient.pqPreKeyRecords)
        XCTAssertNil(mockAPIClient.pqLastResortPreKeyRecord)
    }

    /// Targeting only the PQ key kinds stores 100 one-time and 1 last-resort
    /// Kyber pre-key, and uploads no EC keys.
    func testCreatePqKeysOnly() async throws {
        mockAPIClient.setPreKeysResult = .value(())
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        // Pre-validate
        XCTAssertEqual(aciKyberOneTimePreKeyCount(), 0)
        XCTAssertEqual(aciKyberLastResortPreKeyCount(), 0)
        _ = try await taskManager.refresh(
            identity: .aci,
            targets: [.lastResortPqPreKey, .oneTimePqPreKey],
            force: true,
            auth: .implicit(),
        )
        // Validate
        XCTAssertEqual(aciKyberOneTimePreKeyCount(), 100)
        XCTAssertEqual(aciKyberLastResortPreKeyCount(), 1)
        XCTAssertNil(mockAPIClient.signedPreKeyRecord)
        XCTAssertNil(mockAPIClient.preKeyRecords)
        XCTAssertNotNil(mockAPIClient.pqPreKeyRecords)
        XCTAssertNotNil(mockAPIClient.pqLastResortPreKeyRecord)
    }

    //
    //
    // MARK: - Refresh Tests
    //
    //

    /// Once `SignedPreKeyRotationTime` has elapsed, a non-forced refresh rotates
    /// the signed pre-key: a new record is uploaded and stored next to the
    /// original one.
    func testMockPreKeyTaskUpdate() async throws {
        // Register the same identity key that signs the pre-existing signed
        // pre-key, so the local state is self-consistent.
        let aciKeyPair = ECKeyPair.generateKeyPair()
        mockIdentityManager.identityKeyPairs[.aci] = aciKeyPair
        let originalSignedPreKey = SignedPreKeyStoreImpl.generateSignedPreKey(
            keyId: PreKeyId.random(),
            signedBy: aciKeyPair.keyPair.privateKey,
        )
        mockDb.write { tx in
            mockAciProtocolStore.signedPreKeyStore.storeSignedPreKey(originalSignedPreKey, tx: tx)
        }
        mockAPIClient.setPreKeysResult = .value(())
        mockAPIClient.currentPreKeyCount = 0
        mockAPIClient.currentPqPreKeyCount = 0
        // Move "now" just past the rotation window.
        mockDateProvider.currentDate = Date(timeIntervalSinceNow: PreKeyTaskManager.Constants.SignedPreKeyRotationTime + 1)
        _ = try await taskManager.refresh(identity: .aci, targets: .all, auth: .implicit())
        XCTAssertNotNil(mockAPIClient.signedPreKeyRecord)
        XCTAssertEqual(aciSignedPreKeyCount(), 2)
    }

    /// When the server reports 50 one-time EC pre-keys, a non-forced refresh
    /// neither replaces the 100 locally stored one-time pre-keys nor uploads
    /// new ones.
    func testMockPreKeyTaskNoUpdate() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        let records = mockDb.write { tx in
            let preKeyIds = mockAciProtocolStore.preKeyStore.allocatePreKeyIds(tx: tx)
            let records = PreKeyStoreImpl.generatePreKeyRecords(forPreKeyIds: preKeyIds)
            mockAciProtocolStore.preKeyStore.storePreKeyRecords(records, tx: tx)
            return records
        }
        mockAPIClient.currentPreKeyCount = 50
        mockAPIClient.currentPqPreKeyCount = 0
        XCTAssertEqual(aciPreKeyCount(), 100)
        _ = try await taskManager.refresh(identity: .aci, targets: .all, auth: .implicit())
        XCTAssertEqual(aciPreKeyCount(), 100)
        // Every originally stored record must still be present.
        mockDb.read { tx in
            for record in records {
                XCTAssertNotNil(mockPreKeyStore.aciStore.fetchPreKey(in: .oneTime, for: record.id, tx: tx))
            }
        }
        XCTAssertNil(mockAPIClient.preKeyRecords)
    }

    /// Without an ACI identity key pair, refresh throws
    /// `PreKeyTaskManager.Error.noIdentityKey` and nothing is stored or uploaded.
    func testMockUpdateFailNoIdentity() async throws {
        XCTAssertEqual(aciPreKeyCount(), 0)
        mockAPIClient.currentPreKeyCount = 0
        mockAPIClient.currentPqPreKeyCount = 0
        do {
            _ = try await taskManager.refresh(identity: .aci, targets: .all, auth: .implicit())
            XCTFail("Expected failure, but returned success")
        } catch PreKeyTaskManager.Error.noIdentityKey {
            XCTAssertNil(mockIdentityManager.identityKeyPairs[.aci])
        } catch {
            // Surface the actual error so a failure here is diagnosable.
            XCTFail("Unexpected error: \(error)")
        }
        XCTAssertEqual(aciPreKeyCount(), 0)
        XCTAssertNil(mockAPIClient.preKeyRecords)
    }

    /// With 50 one-time EC pre-keys reported by the server, a non-forced refresh
    /// neither stores nor uploads one-time EC pre-keys.
    // NOTE(review): despite its name, this test only asserts on one-time EC
    // pre-keys, not on the signed pre-key — confirm the intended coverage.
    func testMockUpdateSkipSignedPreKey() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        mockAPIClient.currentPreKeyCount = 50
        mockAPIClient.currentPqPreKeyCount = 0
        XCTAssertEqual(aciPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .all, auth: .implicit())
        XCTAssertEqual(aciPreKeyCount(), 0)
        XCTAssertNil(mockAPIClient.preKeyRecords)
    }

    //
    //
    // MARK: - Force Refresh Tests
    //
    //

    /// With full server counts and a signed-pre-key rotation recorded "now",
    /// a non-forced refresh uploads no one-time EC or signed pre-keys.
    func testRefreshNoUpdatesNeeded() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        mockAPIClient.currentPreKeyCount = 100
        mockAPIClient.currentPqPreKeyCount = 100
        mockDb.write { tx in
            mockAciProtocolStore.signedPreKeyStore.setLastSuccessfulRotationDate(
                mockDateProvider.currentDate,
                tx: tx,
            )
        }
        XCTAssertEqual(aciPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .all, auth: .implicit())
        XCTAssertEqual(aciPreKeyCount(), 0)
        XCTAssertEqual(aciSignedPreKeyCount(), 0)
        XCTAssertNil(mockAPIClient.preKeyRecords)
        XCTAssertNil(mockAPIClient.signedPreKeyRecord)
    }

    /// `force: true` regenerates and uploads one-time and signed pre-keys even
    /// when the server count is full and the rotation date is current.
    func testForceRefreshAll() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        mockAPIClient.currentPreKeyCount = 100
        mockDb.write { tx in
            mockAciProtocolStore.signedPreKeyStore.setLastSuccessfulRotationDate(
                mockDateProvider.currentDate,
                tx: tx,
            )
        }
        XCTAssertEqual(aciPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .all, force: true, auth: .implicit())
        XCTAssertEqual(aciPreKeyCount(), 100)
        XCTAssertEqual(aciSignedPreKeyCount(), 1)
        XCTAssertNotNil(mockAPIClient.preKeyRecords)
        XCTAssertNotNil(mockAPIClient.signedPreKeyRecord)
    }

    /// A forced refresh targeting `.oneTimePreKey` replaces only one-time EC
    /// pre-keys and leaves the signed pre-key alone.
    func testForceRefreshOnlyPreKeys() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        mockAPIClient.currentPreKeyCount = 100
        XCTAssertEqual(aciPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .oneTimePreKey, force: true, auth: .implicit())
        XCTAssertEqual(aciPreKeyCount(), 100)
        XCTAssertEqual(aciSignedPreKeyCount(), 0)
        XCTAssertNotNil(mockAPIClient.preKeyRecords)
        XCTAssertNil(mockAPIClient.signedPreKeyRecord)
    }

    //
    // PNI
    //

    /// An HTTP 422 response while uploading PNI keys makes the task manager ask
    /// the identity-key-mismatch manager to validate the identity key.
    func test422WhileSettingKeysReportsSuspectedPniIdentityKeyIssue() async {
        mockTSAccountManager.registrationStateMock = { .provisioned }
        mockIdentityManager.identityKeyPairs[.pni] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .error(OWSHTTPError.serviceResponse(.init(
            requestUrl: URL(string: "https://example.com")!,
            responseStatus: 422,
            responseHeaders: HttpHeaders(),
            responseData: nil,
        )))
        var didValidateIdentityKey = false
        mockIdentityKeyMismatchManager.validateIdentityKeyMock = { _ in
            didValidateIdentityKey = true
        }
        // The refresh itself is expected to fail; only the side effect matters.
        _ = try? await taskManager.refresh(identity: .pni, targets: .all, force: true, auth: .implicit())
        // Validate
        XCTAssertTrue(didValidateIdentityKey)
    }

    //
    // Test validation
    //

    /// With "now" past `SignedPreKeyRotationTime`, a non-forced signed-pre-key
    /// refresh generates, stores, and uploads one signed pre-key.
    func testSignedPreKeyExpired() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        mockDateProvider.currentDate = Date().addingTimeInterval(PreKeyTaskManager.Constants.SignedPreKeyRotationTime + 1)
        _ = try await taskManager.refresh(identity: .aci, targets: .signedPreKey, auth: .implicit())
        XCTAssertEqual(aciSignedPreKeyCount(), 1)
        XCTAssertNotNil(mockAPIClient.signedPreKeyRecord)
    }

    /// A server count of 9 one-time EC pre-keys is low enough for a non-forced
    /// `.oneTimePreKey` refresh to store and upload 100 new ones.
    func testRefreshOnlyPreKeysBasedOnCount() async throws {
        mockIdentityManager.identityKeyPairs[.aci] = ECKeyPair.generateKeyPair()
        mockAPIClient.setPreKeysResult = .value(())
        mockAPIClient.currentPreKeyCount = 9
        mockAPIClient.currentPqPreKeyCount = 0
        XCTAssertEqual(aciPreKeyCount(), 0)
        _ = try await taskManager.refresh(identity: .aci, targets: .oneTimePreKey, auth: .implicit())
        XCTAssertEqual(aciPreKeyCount(), 100)
        XCTAssertEqual(aciSignedPreKeyCount(), 0)
        XCTAssertNotNil(mockAPIClient.preKeyRecords)
        XCTAssertNil(mockAPIClient.signedPreKeyRecord)
    }
}