Use Cron for linked device PNI identity validation

This commit is contained in:
Max Radermacher 2026-04-13 12:03:26 -05:00 committed by GitHub
parent 899749939a
commit 3b60ab1530
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 81 additions and 100 deletions

View File

@ -845,6 +845,7 @@
50D5E2412980AD6F00899660 /* LinkValidator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50D5E2402980AD6F00899660 /* LinkValidator.swift */; };
50D5E2432980B53000899660 /* LinkValidatorTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50D5E2422980B53000899660 /* LinkValidatorTest.swift */; };
50D6A93F2AA9167400B7F093 /* UniqueObjectRecipientMerger.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50D6A93E2AA9167400B7F093 /* UniqueObjectRecipientMerger.swift */; };
50D6BDEF2ED6724600CC012E /* DeviceType.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50D6BDEE2ED6724600CC012E /* DeviceType.swift */; };
50D8796A2A16D2C20031345D /* MessageLoaderBatchTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50D879692A16D2C20031345D /* MessageLoaderBatchTest.swift */; };
50D9CD8D2C52D78000273D6C /* StoryRecipientManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50D9CD8C2C52D78000273D6C /* StoryRecipientManager.swift */; };
50DCCBFA2F1817280024D124 /* DisappearingMessagesConfigurationMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50DCCBF92F1817280024D124 /* DisappearingMessagesConfigurationMessage.swift */; };
@ -5113,6 +5114,7 @@
50D5E2402980AD6F00899660 /* LinkValidator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinkValidator.swift; sourceTree = "<group>"; };
50D5E2422980B53000899660 /* LinkValidatorTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinkValidatorTest.swift; sourceTree = "<group>"; };
50D6A93E2AA9167400B7F093 /* UniqueObjectRecipientMerger.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UniqueObjectRecipientMerger.swift; sourceTree = "<group>"; };
50D6BDEE2ED6724600CC012E /* DeviceType.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DeviceType.swift; sourceTree = "<group>"; };
50D879692A16D2C20031345D /* MessageLoaderBatchTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MessageLoaderBatchTest.swift; sourceTree = "<group>"; };
50D9CD8C2C52D78000273D6C /* StoryRecipientManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StoryRecipientManager.swift; sourceTree = "<group>"; };
50DCCBF92F1817280024D124 /* DisappearingMessagesConfigurationMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DisappearingMessagesConfigurationMessage.swift; sourceTree = "<group>"; };
@ -15074,6 +15076,7 @@
6659A0242A7C112700066AB7 /* PreKeys */,
661170BF2ABA458800A1B16D /* TSAccountManager */,
50F401CB2D483BF40094CA56 /* DeviceId.swift */,
50D6BDEE2ED6724600CC012E /* DeviceType.swift */,
D9F399AC2A95798A001599EC /* IdentityKeyChecker.swift */,
D9F399B12A96D65D001599EC /* IdentityKeyMismatchManager.swift */,
5033D46629D76BD0007FEADA /* LocalIdentifiers.swift */,
@ -19255,6 +19258,7 @@
72C9058C2B9AC7BD00E586B8 /* DeviceSleepManager.swift in Sources */,
F9C5CC94289453B300548EEE /* DeviceTransfer.pb.swift in Sources */,
F9C5CCA0289453B300548EEE /* DeviceTransferProto.swift in Sources */,
50D6BDEF2ED6724600CC012E /* DeviceType.swift in Sources */,
502D69322A7AC07C0085B656 /* Dictionary+SSK.swift in Sources */,
50DCCBFA2F1817280024D124 /* DisappearingMessagesConfigurationMessage.swift in Sources */,
F9C5CCDA289453B300548EEE /* DisappearingMessagesConfigurationRecord.swift in Sources */,

View File

@ -790,12 +790,7 @@ final class AppDelegate: UIResponder, UIApplicationDelegate {
cron.scheduleFrequently(
mustBeRegistered: true,
mustBeConnected: true,
operation: {
try await blockingManager.syncBlockListIfNecessary(force: false)
},
handleResult: { _ in
// Handled internally by BlockingManager.
},
operation: { try await blockingManager.syncBlockListIfNecessary(force: false) },
)
// Warm the "available emoji" cache, intentionally off the main thread.

View File

@ -184,6 +184,14 @@ public class AppEnvironment: NSObject {
operation: { try await subscriptionConfigManager.refresh() },
)
let identityKeyMismatchManager = DependenciesBridge.shared.identityKeyMismatchManager
cron.scheduleFrequently(
mustBeRegistered: true,
mustBeDeviceType: .linked,
mustBeConnected: true,
operation: { try await identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary() },
)
appReadiness.runNowOrWhenAppWillBecomeReady {
self.badgeManager.startObservingChanges(in: DependenciesBridge.shared.databaseChangeObserver)
self.appIconBadgeUpdater.startObserving()
@ -200,7 +208,6 @@ public class AppEnvironment: NSObject {
let callRecordQuerier = DependenciesBridge.shared.callRecordQuerier
let db = DependenciesBridge.shared.db
let groupCallPeekClient = SSKEnvironment.shared.groupCallManagerRef.groupCallPeekClient
let identityKeyMismatchManager = DependenciesBridge.shared.identityKeyMismatchManager
let interactionStore = DependenciesBridge.shared.interactionStore
let masterKeySyncManager = DependenciesBridge.shared.masterKeySyncManager
let notificationPresenter = SSKEnvironment.shared.notificationPresenterRef
@ -276,9 +283,6 @@ public class AppEnvironment: NSObject {
registeredState: registeredState,
)
} else {
Task {
await identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
}
}
Task {

View File

@ -0,0 +1,11 @@
//
// Copyright 2025 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
public enum DeviceType {
case primary
case linked
}

View File

@ -37,9 +37,9 @@ public protocol IdentityKeyMismatchManager {
///
/// We do not expect many devices to have ended up in a bad state, and so we
/// hope that this unlinking will be a rare last resort.
func validateLocalPniIdentityKeyIfNecessary() async
func validateLocalPniIdentityKeyIfNecessary() async throws
func validateIdentityKey(for identity: OWSIdentity) async
func validateIdentityKey(for identity: OWSIdentity) async throws
}
class IdentityKeyMismatchManagerImpl: IdentityKeyMismatchManager {
@ -58,8 +58,6 @@ class IdentityKeyMismatchManagerImpl: IdentityKeyMismatchManager {
private let tsAccountManager: TSAccountManager
private let whoAmIManager: any WhoAmIManager
private let isValidating = AtomicBool(false, lock: .init())
init(
db: any DB,
identityKeyChecker: IdentityKeyChecker,
@ -89,26 +87,16 @@ class IdentityKeyMismatchManagerImpl: IdentityKeyMismatchManager {
)
}
func validateLocalPniIdentityKeyIfNecessary() async {
let logger = logger
private let taskQueue = ConcurrentTaskQueue(concurrentLimit: 1)
guard isValidating.tryToSetFlag() else {
logger.warn("Skipping validation - already in flight!")
return
}
defer {
self.isValidating.set(false)
func validateLocalPniIdentityKeyIfNecessary() async throws {
try await taskQueue.run {
try await _validateLocalPniIdentityKeyIfNecessary()
}
}
guard tsAccountManager.registrationStateWithMaybeSneakyTransaction.isPrimaryDevice == false else {
return
}
do throws(CancellationError) {
try await self.messageProcessor.waitForFetchingAndProcessing()
} catch {
return
}
private func _validateLocalPniIdentityKeyIfNecessary() async throws {
try await self.messageProcessor.waitForFetchingAndProcessing()
let hasSuspectedIssue = self.db.read { tx in
return self.kvStore.getBool(
@ -121,10 +109,10 @@ class IdentityKeyMismatchManagerImpl: IdentityKeyMismatchManager {
return
}
await validateIdentityKey(for: .pni)
try await validateIdentityKey(for: .pni)
}
func validateIdentityKey(for identity: OWSIdentity) async {
func validateIdentityKey(for identity: OWSIdentity) async throws {
let logger = logger
logger.info("Validating identity key for \(identity)")
do {
@ -139,9 +127,8 @@ class IdentityKeyMismatchManagerImpl: IdentityKeyMismatchManager {
}
}
} catch {
// Eat all the errors -- the caller should be triggering this in response
// to its own error, and we always want to pass that error to the caller.
logger.warn("Couldn't validate identity key: \(error)")
throw error
}
}

View File

@ -503,8 +503,9 @@ struct PreKeyTaskManager {
}
case let .failure(error) where error.httpStatusCode == 422:
// We think we might have an incorrect identity key -- check it and
// deregister if it's wrong.
await self.identityKeyMismatchManager.validateIdentityKey(for: identity)
// deregister if it's wrong. We always eat this error because we want the
// caller to see the `error` from `uploadResult`.
try? await self.identityKeyMismatchManager.validateIdentityKey(for: identity)
fallthrough
case let .failure(error):
logger.info("[\(identity)] Failed to upload prekeys")

View File

@ -87,7 +87,7 @@ extension TSRegistrationState {
}
}
public var isPrimaryDevice: Bool? {
public var deviceType: DeviceType? {
switch self {
case .unregistered, .transferringIncoming:
// We don't yet know if this will be a primary
@ -97,9 +97,17 @@ extension TSRegistrationState {
// Irrelevant what this was, return nil.
return nil
case .registered, .deregistered, .reregistering, .transferringPrimaryOutgoing:
return true
return .primary
case .provisioned, .delinked, .relinking, .transferringLinkedOutgoing:
return false
return .linked
}
}
public var isPrimaryDevice: Bool? {
switch self.deviceType {
case .primary: true
case .linked: false
case .none: nil
}
}

View File

@ -128,6 +128,9 @@ public class Cron {
/// - Parameter mustBeRegistered: If true, `operation` won't be invoked
/// until the user is registered.
///
/// - Parameter mustBeDeviceType: If set, `operation` won't be invoked
/// unless the user is (or was) registered as the specific device type.
///
/// - Parameter mustBeConnected: If true, `operation` won't be invoked until
/// the user is connected.
///
@ -140,6 +143,7 @@ public class Cron {
uniqueKey: UniqueKey,
approximateInterval: TimeInterval,
mustBeRegistered: Bool,
mustBeDeviceType: DeviceType? = nil,
mustBeConnected: Bool,
isRetryable: @escaping (E) -> Bool = { $0.isRetryable },
operation: @escaping () async throws(E) -> Void,
@ -147,6 +151,7 @@ public class Cron {
let store = CronStore(uniqueKey: uniqueKey)
scheduleFrequently(
mustBeRegistered: mustBeRegistered,
mustBeDeviceType: mustBeDeviceType,
mustBeConnected: mustBeConnected,
maxAverageBackoff: approximateInterval,
isRetryable: isRetryable,
@ -162,10 +167,10 @@ public class Cron {
},
handleResult: { [db] result in
switch result {
case .failure(is NotRegisteredError), .success(false), .failure(is CancellationError):
// A requirement (e.g., mustBeRegistered) wasn't met, it's too early to run
// again, or we were canceled while running. Don't set any state so that we
// run again at the next opportunity.
case .success(false), .failure(is CancellationError):
// It's too early to run again or we were canceled while running/waiting to
// run (e.g., while waiting for a connection). Don't set any state so that
// we run again at the next opportunity.
break
case .success(true), .failure:
// We ran or hit a terminal error while trying to run; mark the job as
@ -208,6 +213,9 @@ public class Cron {
/// - Parameter mustBeRegistered: If true, `operation` won't be invoked
/// until the user is registered.
///
/// - Parameter mustBeDeviceType: If set, `operation` won't be invoked
/// unless the user is (or was) registered as the specific device type.
///
/// - Parameter mustBeConnected: If true, `operation` won't be invoked until
/// the user is connected.
///
@ -230,17 +238,19 @@ public class Cron {
/// network) and `NotRegisteredError`s that may be thrown.
public func scheduleFrequently<T, E>(
mustBeRegistered: Bool,
mustBeDeviceType: DeviceType? = nil,
mustBeConnected: Bool,
minAverageBackoff: TimeInterval = ExponentialBackoff.Defaults.minAverageBackoff,
maxAverageBackoff: TimeInterval = ExponentialBackoff.Defaults.maxAverageBackoff,
isRetryable: @escaping (E) -> Bool = { $0.isRetryable },
operation: @escaping () async throws(E) -> T,
handleResult: @escaping (Result<T, any Error>) async -> Void,
handleResult: @escaping (Result<T, any Error>) async -> Void = { _ in },
) {
self.jobs.update {
$0.append({ ctx async -> Void in
let attemptResult = await Self.runOuterOperationAttempt(
mustBeRegistered: mustBeRegistered,
mustBeDeviceType: mustBeDeviceType,
mustBeConnected: mustBeConnected,
minAverageBackoff: minAverageBackoff,
maxAverageBackoff: maxAverageBackoff,
@ -261,6 +271,7 @@ public class Cron {
/// throws a non-`isRetryable` error.
private static func runOuterOperationAttempt<T, E>(
mustBeRegistered: Bool,
mustBeDeviceType: DeviceType?,
mustBeConnected: Bool,
minAverageBackoff: TimeInterval,
maxAverageBackoff: TimeInterval,
@ -277,6 +288,7 @@ public class Cron {
block: { () throws(E) -> Result<T, any Error> in
return try await runInnerOperationAttempt(
mustBeRegistered: mustBeRegistered,
mustBeDeviceType: mustBeDeviceType,
mustBeConnected: mustBeConnected,
operation: operation,
ctx: ctx,
@ -299,6 +311,7 @@ public class Cron {
/// `operation` is invoked. All errors are immediately rethrown.
private static func runInnerOperationAttempt<T, E>(
mustBeRegistered: Bool,
mustBeDeviceType: DeviceType?,
mustBeConnected: Bool,
operation: () async throws(E) -> T,
ctx: CronContext,
@ -320,6 +333,9 @@ public class Cron {
if mustBeRegistered, !ctx.tsAccountManager.registrationStateWithMaybeSneakyTransaction.isRegistered {
return .failure(NotRegisteredError())
}
if let mustBeDeviceType, ctx.tsAccountManager.registrationStateWithMaybeSneakyTransaction.deviceType != mustBeDeviceType {
return .failure(OWSGenericError("must be \(mustBeDeviceType)"))
}
return .success(try await operation())
}

View File

@ -282,9 +282,6 @@ public final class KeyTransparencyManager {
try await prepareAndPerformSelfCheck(localIdentifiers: localIdentifiers)
},
handleResult: { _ in
// prepareAndPerformSelfCheck manages Cron state internally.
},
)
}

View File

@ -275,7 +275,7 @@ public class OWSMessageDecrypter {
let identityKeyMismatchManager = DependenciesBridge.shared.identityKeyMismatchManager
identityKeyMismatchManager.recordSuspectedIssueWithPniIdentityKey(tx: transaction)
Task {
await identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
try await identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
}
errorMessage = .failedDecryption(

View File

@ -78,7 +78,7 @@ public class AccountAttributesUpdaterImpl: AccountAttributesUpdater {
},
handleResult: { result in
switch result {
case .failure(is NotRegisteredError), .success(false), .failure(is CancellationError):
case .success(false), .failure(is CancellationError):
break
case .success(true):
// Handled by updateAccountAttributes.

View File

@ -68,16 +68,17 @@ final class IdentityKeyMismatchManagerTest: XCTestCase {
identityKeyMismatchManager.recordSuspectedIssueWithPniIdentityKey(tx: tx)
}
}
await identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
try? await identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
}
func testDoesntRecordIfPrimaryDevice() async {
messageProcessorMock.waitForFetchingAndProcessingMock = {}
tsAccountManagerMock.registrationStateMock = { .registered }
await db.awaitableWrite { tx in
return identityKeyMismatchManager.recordSuspectedIssueWithPniIdentityKey(tx: tx)
}
await identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
try? await identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
XCTAssertFalse(kvStore.hasDecryptionError())
}
@ -140,6 +141,7 @@ final class IdentityKeyMismatchManagerTest: XCTestCase {
}
func testEarlyExitIfPrimary() async {
messageProcessorMock.waitForFetchingAndProcessingMock = {}
tsAccountManagerMock.registrationStateMock = { .registered }
// This will fail if it doesn't early-exit, due to missing mocks.
@ -181,50 +183,6 @@ final class IdentityKeyMismatchManagerTest: XCTestCase {
XCTAssertTrue(self.isMarkedDeregistered)
XCTAssertEqual(serverHasSameKeyResponses, [])
}
/// Checks that multiple overlapping validation attempts are collapsed into
/// one. Also checks that a subsequent validation runs.
func testMultipleCallsResultInOneRun() async {
let fetchingAndProcessing = CancellableContinuation<Void>()
messageProcessorMock.waitForFetchingAndProcessingMock = { try! await fetchingAndProcessing.wait() }
let localIdentifiers = LocalIdentifiers.mock
whoAmIManagerMock.whoAmIResponse = .value(.forUnitTest(localIdentifiers: localIdentifiers))
tsAccountManagerMock.localIdentifiersMock = { localIdentifiers }
var serverHasSameKeyResponses = [true]
identityKeyCheckerMock.serverHasSameKeyAsLocalMock = { _, _ in serverHasSameKeyResponses.popFirst()! }
await db.awaitableWrite { tx in
identityKeyMismatchManager.recordSuspectedIssueWithPniIdentityKey(tx: tx)
}
await withTaskGroup(of: Void.self) { taskGroup in
taskGroup.addTask {
await self.identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
}
taskGroup.addTask {
await self.identityKeyMismatchManager.validateLocalPniIdentityKeyIfNecessary()
}
// One of the two Tasks should be able to complete immediately.
_ = await taskGroup.next()
// Once it does, we can let the other one complete as well.
fetchingAndProcessing.resume(with: .success(()))
_ = await taskGroup.next()
}
XCTAssertFalse(kvStore.hasDecryptionError())
XCTAssertFalse(self.isMarkedDeregistered)
XCTAssertEqual(serverHasSameKeyResponses, [])
messageProcessorMock.waitForFetchingAndProcessingMock = {}
whoAmIManagerMock.whoAmIResponse = .value(.forUnitTest(localIdentifiers: localIdentifiers))
tsAccountManagerMock.localIdentifiersMock = { localIdentifiers }
serverHasSameKeyResponses = [false]
await runRunRun(recordIssue: true)
XCTAssertFalse(kvStore.hasDecryptionError())
XCTAssertTrue(self.isMarkedDeregistered)
XCTAssertEqual(serverHasSameKeyResponses, [])
}
}
// MARK: - Mocks

View File

@ -26,12 +26,12 @@ class _PreKeyTaskManager_IdentityKeyMismatchManagerMock: IdentityKeyMismatchMana
func recordSuspectedIssueWithPniIdentityKey(tx: DBWriteTransaction) {
}
func validateLocalPniIdentityKeyIfNecessary() async {
func validateLocalPniIdentityKeyIfNecessary() async throws {
}
var validateIdentityKeyMock: ((_ identity: OWSIdentity) async -> Void)!
func validateIdentityKey(for identity: OWSIdentity) async {
await validateIdentityKeyMock!(identity)
var validateIdentityKeyMock: ((_ identity: OWSIdentity) async throws -> Void)!
func validateIdentityKey(for identity: OWSIdentity) async throws {
try await validateIdentityKeyMock!(identity)
}
}