diff --git a/Signal.xcodeproj/project.pbxproj b/Signal.xcodeproj/project.pbxproj index 69a9ae9733..0950c5a243 100644 --- a/Signal.xcodeproj/project.pbxproj +++ b/Signal.xcodeproj/project.pbxproj @@ -3180,6 +3180,7 @@ D9AE0AD929187F850063488B /* MessageSenderJobRecord.swift in Sources */ = {isa = PBXBuildFile; fileRef = D9AE0AD829187F850063488B /* MessageSenderJobRecord.swift */; }; D9AE0ADD2918B2960063488B /* JobRecord+Columns.swift in Sources */ = {isa = PBXBuildFile; fileRef = D9AE0ADC2918B2960063488B /* JobRecord+Columns.swift */; }; D9B0AC7429EF42960070F31C /* TSInfoMessage+GroupUpdates+DisplayableGroupUpdateItem.swift in Sources */ = {isa = PBXBuildFile; fileRef = D9B0AC7329EF42960070F31C /* TSInfoMessage+GroupUpdates+DisplayableGroupUpdateItem.swift */; }; + D9B1A8BF2FB7B69200CE5FD3 /* FailIfThrowsRecordCursor.swift in Sources */ = {isa = PBXBuildFile; fileRef = D9B1A8BE2FB7B68C00CE5FD3 /* FailIfThrowsRecordCursor.swift */; }; D9B2E1182E748E1900A823E4 /* OWSByteCountFormatStyle.swift in Sources */ = {isa = PBXBuildFile; fileRef = D9B2E1172E748DFB00A823E4 /* OWSByteCountFormatStyle.swift */; }; D9B8541229137C150058F97B /* JobRecord.swift in Sources */ = {isa = PBXBuildFile; fileRef = D9B8541129137C150058F97B /* JobRecord.swift */; }; D9B95A9629E6830B00D7CB95 /* JobRecordTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = D9B95A9429E682E900D7CB95 /* JobRecordTest.swift */; }; @@ -7473,6 +7474,7 @@ D9AE0AD829187F850063488B /* MessageSenderJobRecord.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MessageSenderJobRecord.swift; sourceTree = ""; }; D9AE0ADC2918B2960063488B /* JobRecord+Columns.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "JobRecord+Columns.swift"; sourceTree = ""; }; D9B0AC7329EF42960070F31C /* TSInfoMessage+GroupUpdates+DisplayableGroupUpdateItem.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "TSInfoMessage+GroupUpdates+DisplayableGroupUpdateItem.swift"; sourceTree = ""; }; + D9B1A8BE2FB7B68C00CE5FD3 /* FailIfThrowsRecordCursor.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FailIfThrowsRecordCursor.swift; sourceTree = ""; }; D9B2E1172E748DFB00A823E4 /* OWSByteCountFormatStyle.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OWSByteCountFormatStyle.swift; sourceTree = ""; }; D9B8541129137C150058F97B /* JobRecord.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = JobRecord.swift; sourceTree = ""; }; D9B91D8D2B17E2A600BCB11A /* GroupCallRecordRingUpdateDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GroupCallRecordRingUpdateDelegate.swift; sourceTree = ""; }; @@ -15238,6 +15240,7 @@ F9B652C228D8E3DF006914CA /* DatabaseRecovery.swift */, D9FF515B2F03A2A10011982F /* DBUInt64.swift */, F9C5CA48289453B100548EEE /* DeepCopy.swift */, + D9B1A8BE2FB7B68C00CE5FD3 /* FailIfThrowsRecordCursor.swift */, F9C5CA40289453B100548EEE /* GRDBDatabaseStorageAdapter.swift */, F9C5CA47289453B100548EEE /* GRDBSchemaMigrator.swift */, D9B95A9929E8918200D7CB95 /* InMemoryDB.swift */, @@ -19315,6 +19318,7 @@ F9C5CE57289453B400548EEE /* Factories.swift in Sources */, F9C5CC1D289453B300548EEE /* FailedMessagesJob.swift in Sources */, 7255A4C82B98DF3E00E95368 /* FailedStorySendDisplayController.swift in Sources */, + D9B1A8BF2FB7B69200CE5FD3 /* FailIfThrowsRecordCursor.swift in Sources */, F9C5CE60289453B400548EEE /* FakeContactsManager.swift in Sources */, F94BFA9528EBB0D800A5F34E /* FakeMessageSender.swift in Sources */, F9C5CE54289453B400548EEE /* FakeStorageServiceManager.swift in Sources */, diff --git a/Signal/AppLaunch/AppEnvironment.swift b/Signal/AppLaunch/AppEnvironment.swift index d4de3225c1..8051f97f10 100644 --- a/Signal/AppLaunch/AppEnvironment.swift +++ b/Signal/AppLaunch/AppEnvironment.swift @@ -284,11 +284,7 @@ public class AppEnvironment: NSObject { // Things that should run on either the primary or linked devices. if let registeredState, registeredState.isPrimary { Task { - do { - try await avatarDefaultColorStorageServiceMigrator.performMigrationIfNecessary() - } catch { - Logger.warn("Couldn't perform avatar default color migration: \(error)") - } + await avatarDefaultColorStorageServiceMigrator.performMigrationIfNecessary() } Task { diff --git a/Signal/Avatars/AvatarDefaultColorStorageServiceMigrator.swift b/Signal/Avatars/AvatarDefaultColorStorageServiceMigrator.swift index 608622d455..98424474d1 100644 --- a/Signal/Avatars/AvatarDefaultColorStorageServiceMigrator.swift +++ b/Signal/Avatars/AvatarDefaultColorStorageServiceMigrator.swift @@ -32,8 +32,8 @@ struct AvatarDefaultColorStorageServiceMigrator { self.threadStore = threadStore } - func performMigrationIfNecessary() async throws { - try await db.awaitableWrite { tx in + func performMigrationIfNecessary() async { + await db.awaitableWrite { tx in if kvStore.hasValue(StoreKeys.hasEnqueuedMigrationKey, transaction: tx) { return } @@ -46,15 +46,14 @@ struct AvatarDefaultColorStorageServiceMigrator { } var groupV2MasterKeys = [GroupMasterKey]() - try threadStore.enumerateGroupThreads(tx: tx) { groupThread in - guard + threadStore.enumerateGroupThreads(tx: tx) { groupThread in + if let groupModelV2 = groupThread.groupModel as? TSGroupModelV2, let groupMasterKey = try? groupModelV2.masterKey() - else { - return true + { + groupV2MasterKeys.append(groupMasterKey) } - groupV2MasterKeys.append(groupMasterKey) return true } diff --git a/Signal/src/ViewControllers/ThreadSettings/AddToGroupViewController.swift b/Signal/src/ViewControllers/ThreadSettings/AddToGroupViewController.swift index 15cf62f62b..d5a157264b 100644 --- a/Signal/src/ViewControllers/ThreadSettings/AddToGroupViewController.swift +++ b/Signal/src/ViewControllers/ThreadSettings/AddToGroupViewController.swift @@ -61,27 +61,23 @@ public class AddToGroupViewController: OWSTableViewController2 { return databaseStorage.read { transaction in var result = [TSGroupThread]() - do { - try ThreadFinder().enumerateGroupThreads(transaction: transaction) { thread -> Bool in - if thread.isGroupV2Thread { - let groupViewHelper = GroupViewHelper( - threadViewModel: ThreadViewModel( - thread: thread, - forChatList: false, - transaction: transaction, - ), - memberLabelCoordinator: nil, - ) + ThreadFinder().enumerateGroupThreads(tx: transaction) { groupThread -> Bool in + if groupThread.isGroupV2Thread { + let groupViewHelper = GroupViewHelper( + threadViewModel: ThreadViewModel( + thread: groupThread, + forChatList: false, + transaction: transaction, + ), + memberLabelCoordinator: nil, + ) - if groupViewHelper.canEditConversationMembership { - result.append(thread) - } + if groupViewHelper.canEditConversationMembership { + result.append(groupThread) } - - return true } - } catch { - owsFailDebug("Failed to fetch group threads: \(error). Returning an empty array") + + return true } return result diff --git a/SignalServiceKit/Backups/Archiving/Archivers/AdHocCall/BackupArchiveAdHocCallArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/AdHocCall/BackupArchiveAdHocCallArchiver.swift index 4d3f30e696..7e25f7c338 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/AdHocCall/BackupArchiveAdHocCallArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/AdHocCall/BackupArchiveAdHocCallArchiver.swift @@ -67,61 +67,58 @@ public class BackupArchiveAdHocCallArchiver: BackupArchiveProtoStreamWriter { context: BackupArchive.ChatArchivingContext, ) throws(CancellationError) -> ArchiveMultiFrameResult { var partialErrors = [ArchiveFrameError]() - do { - try context.bencher.wrapEnumeration( - callRecordStore.enumerateAdHocCallRecords(tx:block:), - tx: context.tx, - ) { record, frameBencher in - try Task.checkCancellation() - autoreleasepool { - let callTimestamp = record.callBeganTimestamp - guard BackupArchive.Timestamps.isValid(callTimestamp) else { - partialErrors.append(.archiveFrameError(.invalidAdHocCallTimestamp)) - return - } - var adHocCallProto = BackupProto_AdHocCall() - adHocCallProto.callID = record.callId - adHocCallProto.callTimestamp = record.callBeganTimestamp - - // It's a cross-client decision that `state` can only - // ever be `.generic` (even if the client state is - // actually `.joined`). - adHocCallProto.state = .generic - - guard - let callLinkRecordId = BackupArchive.CallLinkRecordId(callRecordConversationId: record.conversationId) - else { - partialErrors.append(.archiveFrameError(.adHocCallDoesNotHaveCallLinkAsConversationId)) - return - } - guard let recipientId = context.recipientContext[.callLink(callLinkRecordId)] else { - partialErrors.append(.archiveFrameError( - .referencedRecipientIdMissing(.callLink(callLinkRecordId)), - )) - return - } - adHocCallProto.recipientID = recipientId.value - - let error: ArchiveFrameError? = Self.writeFrameToStream( - stream, - frameBencher: frameBencher, - ) { - var frame = BackupProto_Frame() - frame.adHocCall = adHocCallProto - return frame - } - - if let error { - partialErrors.append(error) - } + try context.bencher.wrapEnumeration( + tx: context.tx, + enumerationBlock: { tx, block throws(CancellationError) in + try callRecordStore.enumerateAdHocCallRecords(tx: tx, block: block) + }, + perEnumerantBlock: { record, frameBencher -> Bool in + let callTimestamp = record.callBeganTimestamp + guard BackupArchive.Timestamps.isValid(callTimestamp) else { + partialErrors.append(.archiveFrameError(.invalidAdHocCallTimestamp)) + return true } - } - } catch let error as CancellationError { - throw error - } catch { - return .completeFailure(.fatalArchiveError(.adHocCallIteratorError(error))) - } + + var adHocCallProto = BackupProto_AdHocCall() + adHocCallProto.callID = record.callId + adHocCallProto.callTimestamp = record.callBeganTimestamp + + // It's a cross-client decision that `state` can only + // ever be `.generic` (even if the client state is + // actually `.joined`). + adHocCallProto.state = .generic + + guard + let callLinkRecordId = BackupArchive.CallLinkRecordId(callRecordConversationId: record.conversationId) + else { + partialErrors.append(.archiveFrameError(.adHocCallDoesNotHaveCallLinkAsConversationId)) + return true + } + guard let recipientId = context.recipientContext[.callLink(callLinkRecordId)] else { + partialErrors.append(.archiveFrameError( + .referencedRecipientIdMissing(.callLink(callLinkRecordId)), + )) + return true + } + adHocCallProto.recipientID = recipientId.value + + let error: ArchiveFrameError? = Self.writeFrameToStream( + stream, + frameBencher: frameBencher, + ) { + var frame = BackupProto_Frame() + frame.adHocCall = adHocCallProto + return frame + } + + if let error { + partialErrors.append(error) + } + + return true + }, + ) if partialErrors.isEmpty { return .success diff --git a/SignalServiceKit/Backups/Archiving/Archivers/BackupArchive+Errors.swift b/SignalServiceKit/Backups/Archiving/Archivers/BackupArchive+Errors.swift index ee307c46eb..3e486d02a2 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/BackupArchive+Errors.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/BackupArchive+Errors.swift @@ -504,40 +504,10 @@ extension BackupArchive { /// single frame. public struct FatalArchivingError: BackupArchive.LoggableError { public enum ErrorType { - /// Error iterating over all SignalRecipients for backup purposes. - case recipientIteratorError(RawError) - - /// Error iterating over all threads for backup purposes. - case threadIteratorError(RawError) - /// We fetched a thread (via the iterator) with no sqlite row id. - case fetchedThreadMissingRowId - - /// Some unrecognized thread was found when iterating over all threads. - case unrecognizedThreadType - - /// Error iterating over all interactions for backup purposes. - case interactionIteratorError(RawError) - /// We fetched an interaction (via the iterator) with no sqlite row id. - case fetchedInteractionMissingRowId - - /// Error fetching reactions for a message. - case reactionIteratorError(RawError) - - /// Error iterating over all sticker packs for backup purposes. - case stickerPackIteratorError(RawError) - - /// Error iterating over all call link records for backup purposes. - case callLinkRecordIteratorError(RawError) - - /// Error iterating over all ad hoc calls for backup purposes. - case adHocCallIteratorError(RawError) - - case oversizedTextCacheFetchError(RawError) - - /// These should never happen; it means some invariant in the backup code - /// we could not enforce with the type system was broken. Nothing was wrong with - /// the proto or local database; its the iOS backup code that has a bug somewhere. - case developerError(OWSAssertionError) + /// An code-level invariant of some sort was violated in the Backups + /// archiving code; for example, some codepath found an object type + /// that should be handled elsewhere. + case developerError(message: String) } private let type: ErrorType diff --git a/SignalServiceKit/Backups/Archiving/Archivers/Chat/BackupArchiveChatArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/Chat/BackupArchiveChatArchiver.swift index a1ccd23f35..e34e67bed5 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/Chat/BackupArchiveChatArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/Chat/BackupArchiveChatArchiver.swift @@ -50,74 +50,32 @@ public class BackupArchiveChatArchiver: BackupArchiveProtoStreamWriter { var completeFailureError: BackupArchive.FatalArchivingError? var partialErrors = [ArchiveFrameError]() - func archiveThread(_ thread: TSThread, _ frameBencher: BackupArchive.Bencher.FrameBencher) -> Bool { - var stop = false - autoreleasepool { - let result: ArchiveMultiFrameResult - if let thread = thread as? TSContactThread { - // Check address directly; isNoteToSelf uses global state. - if thread.contactAddress.isEqualToAddress(context.recipientContext.localIdentifiers.aciAddress) { - result = self.archiveNoteToSelfThread( - thread, - stream: stream, - frameBencher: frameBencher, - context: context, - ) - } else { - result = self.archiveContactThread( - thread, - stream: stream, - frameBencher: frameBencher, - context: context, - ) - } - } else if let thread = thread as? TSGroupThread, thread.isGroupV2Thread { - result = self.archiveGroupV2Thread( - thread, - stream: stream, - frameBencher: frameBencher, - context: context, - ) - } else if let thread = thread as? TSGroupThread, thread.isGroupV1Thread { - // Remember which threads were gv1 so we can silently drop their messages. - context.gv1ThreadIds.insert(thread.uniqueThreadIdentifier) - // Skip gv1 threads; count as success. - result = .success - } else if thread.isReleaseNotesThread { - // TODO: [KC] implement release notes in backups - result = .success - } else { - result = .completeFailure(.fatalArchiveError(.unrecognizedThreadType)) - } + try context.bencher.wrapEnumeration( + tx: context.tx, + enumerationBlock: { tx, block throws(CancellationError) in + try threadStore.enumerateNonStoryThreads(tx: tx, block: block) + }, + perEnumerantBlock: { [self] thread, frameBencher in + let result = archiveThread( + thread: thread, + stream: stream, + frameBencher: frameBencher, + context: context, + ) switch result { case .success: break case .completeFailure(let error): completeFailureError = error - stop = true - return + return false case .partialSuccess(let errors): partialErrors.append(contentsOf: errors) } - } - return !stop - } - - do { - try context.bencher.wrapEnumeration( - threadStore.enumerateNonStoryThreads(tx:block:), - tx: context.tx, - ) { thread, frameBencher in - try Task.checkCancellation() - return archiveThread(thread, frameBencher) - } - } catch let error as CancellationError { - throw error - } catch let error { - return .completeFailure(.fatalArchiveError(.threadIteratorError(error))) - } + return true + }, + ) if let completeFailureError { return .completeFailure(completeFailureError) @@ -128,18 +86,60 @@ public class BackupArchiveChatArchiver: BackupArchiveProtoStreamWriter { } } - private func archiveNoteToSelfThread( - _ thread: TSContactThread, + private func archiveThread( + thread: TSThread, stream: BackupArchiveProtoOutputStream, frameBencher: BackupArchive.Bencher.FrameBencher, context: BackupArchive.ChatArchivingContext, ) -> ArchiveMultiFrameResult { - guard let threadRowId = thread.sqliteRowId else { - return .completeFailure(.fatalArchiveError( - .fetchedThreadMissingRowId, - )) + if let thread = thread as? TSContactThread { + if thread.contactAddress.isEqualToAddress(context.recipientContext.localIdentifiers.aciAddress) { + return archiveNoteToSelfThread( + thread, + threadRowId: thread.sqliteRowId!, + stream: stream, + frameBencher: frameBencher, + context: context, + ) + } else { + return self.archiveContactThread( + thread, + threadRowId: thread.sqliteRowId!, + stream: stream, + frameBencher: frameBencher, + context: context, + ) + } + } else if let thread = thread as? TSGroupThread, thread.isGroupV2Thread { + return archiveGroupV2Thread( + thread, + threadRowId: thread.sqliteRowId!, + stream: stream, + frameBencher: frameBencher, + context: context, + ) + } else if let thread = thread as? TSGroupThread, thread.isGroupV1Thread { + // Remember which threads were gv1 so we can silently drop their messages. + context.gv1ThreadIds.insert(thread.uniqueThreadIdentifier) + // Skip gv1 threads; count as success. + return .success + } else if thread.isReleaseNotesThread { + // TODO: [KC] implement release notes in backups + return .success + } else { + return .completeFailure(.fatalArchiveError(.developerError( + message: "Unexpected thread type! \(type(of: thread))", + ))) } + } + private func archiveNoteToSelfThread( + _ thread: TSContactThread, + threadRowId: TSThread.RowId, + stream: BackupArchiveProtoOutputStream, + frameBencher: BackupArchive.Bencher.FrameBencher, + context: BackupArchive.ChatArchivingContext, + ) -> ArchiveMultiFrameResult { return archiveThread( BackupArchive.ChatThread(threadType: .contact(thread), threadRowId: threadRowId), recipientId: context.recipientContext.localRecipientId, @@ -151,6 +151,7 @@ public class BackupArchiveChatArchiver: BackupArchiveProtoStreamWriter { private func archiveContactThread( _ thread: TSContactThread, + threadRowId: TSThread.RowId, stream: BackupArchiveProtoOutputStream, frameBencher: BackupArchive.Bencher.FrameBencher, context: BackupArchive.ChatArchivingContext, @@ -185,12 +186,6 @@ public class BackupArchiveChatArchiver: BackupArchiveProtoStreamWriter { } } - guard let threadRowId = thread.sqliteRowId else { - return .completeFailure(.fatalArchiveError( - .fetchedThreadMissingRowId, - )) - } - return archiveThread( BackupArchive.ChatThread(threadType: .contact(thread), threadRowId: threadRowId), recipientId: recipientId, @@ -202,6 +197,7 @@ public class BackupArchiveChatArchiver: BackupArchiveProtoStreamWriter { private func archiveGroupV2Thread( _ thread: TSGroupThread, + threadRowId: TSThread.RowId, stream: BackupArchiveProtoOutputStream, frameBencher: BackupArchive.Bencher.FrameBencher, context: BackupArchive.ChatArchivingContext, @@ -213,12 +209,6 @@ public class BackupArchiveChatArchiver: BackupArchiveProtoStreamWriter { return .partialSuccess([.archiveFrameError(.referencedRecipientIdMissing(recipientAddress))]) } - guard let threadRowId = thread.sqliteRowId else { - return .completeFailure(.fatalArchiveError( - .fetchedThreadMissingRowId, - )) - } - return archiveThread( BackupArchive.ChatThread(threadType: .groupV2(thread), threadRowId: threadRowId), recipientId: recipientId, diff --git a/SignalServiceKit/Backups/Archiving/Archivers/Chat/BackupArchiveThreadStore.swift b/SignalServiceKit/Backups/Archiving/Archivers/Chat/BackupArchiveThreadStore.swift index f9163165d0..62e246d65a 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/Chat/BackupArchiveThreadStore.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/Chat/BackupArchiveThreadStore.swift @@ -17,22 +17,22 @@ public final class BackupArchiveThreadStore { func enumerateNonStoryThreads( tx: DBReadTransaction, - block: (TSThread) throws -> Bool, - ) throws { + block: (TSThread) throws(CancellationError) -> Bool, + ) throws(CancellationError) { try threadStore.enumerateNonStoryThreads(tx: tx, block: block) } func enumerateGroupThreads( tx: DBReadTransaction, - block: (TSGroupThread) throws -> Bool, - ) throws { + block: (TSGroupThread) throws(CancellationError) -> Bool, + ) throws(CancellationError) { try threadStore.enumerateGroupThreads(tx: tx, block: block) } func enumerateStoryThreads( tx: DBReadTransaction, - block: (TSPrivateStoryThread) throws -> Bool, - ) throws { + block: (TSPrivateStoryThread) throws(CancellationError) -> Bool, + ) throws(CancellationError) { try threadStore.enumerateStoryThreads(tx: tx, block: block) } diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveChatItemArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveChatItemArchiver.swift index 69e9eaff79..1aa028d23a 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveChatItemArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveChatItemArchiver.swift @@ -136,11 +136,16 @@ public class BackupArchiveChatItemArchiver: BackupArchiveProtoStreamWriter { var completeFailureError: BackupArchive.FatalArchivingError? var partialFailures = [ArchiveFrameError]() - func archiveInteraction( - _ interactionRecord: InteractionRecord, - _ frameBencher: BackupArchive.Bencher.FrameBencher, - ) -> Bool { - return autoreleasepool { () -> Bool in + try context.bencher.wrapEnumeration( + tx: context.tx, + enumerationBlock: { tx, block throws(CancellationError) in + var cursor = FailIfThrowsRecordCursor { + try InteractionRecord.fetchCursor(tx.database) + } + + while let interactionRecord = cursor.next(), try block(interactionRecord) {} + }, + perEnumerantBlock: { [self] interactionRecord, frameBencher -> Bool in let interaction: TSInteraction do { interaction = try TSInteraction.fromRecord(interactionRecord) @@ -149,13 +154,12 @@ public class BackupArchiveChatItemArchiver: BackupArchiveProtoStreamWriter { return true } - let result = self.archiveInteraction( + switch archiveInteraction( interaction, stream: stream, frameBencher: frameBencher, context: context, - ) - switch result { + ) { case .success: return true case .partialSuccess(let errors): @@ -165,32 +169,8 @@ public class BackupArchiveChatItemArchiver: BackupArchiveProtoStreamWriter { completeFailureError = error return false } - } - } - - do { - try context.bencher.wrapEnumeration( - { tx, block in - let cursor = try InteractionRecord - .fetchCursor(tx.database) - - while - let interactionRecord = try cursor.next(), - try block(interactionRecord) - {} - }, - tx: context.tx, - ) { interactionRecord, frameBencher in - try Task.checkCancellation() - return archiveInteraction(interactionRecord, frameBencher) - } - } catch let error as CancellationError { - throw error - } catch let error { - // Errors thrown here are from the iterator's SQL query, - // not the individual interaction handler. - return .completeFailure(.fatalArchiveError(.interactionIteratorError(error))) - } + }, + ) if let completeFailureError { return .completeFailure(completeFailureError) diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveReactionArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveReactionArchiver.swift index 9dd177c538..bf94cc5726 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveReactionArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveReactionArchiver.swift @@ -21,12 +21,7 @@ class BackupArchiveReactionArchiver: BackupArchiveProtoStreamWriter { _ message: TSMessage, context: BackupArchive.RecipientArchivingContext, ) -> BackupArchive.ArchiveInteractionResult<[BackupProto_Reaction]> { - let reactions: [OWSReaction] - do { - reactions = try reactionStore.allReactions(message: message, context: context) - } catch { - return .completeFailure(.fatalArchiveError(.reactionIteratorError(error))) - } + let reactions = reactionStore.allReactions(message: message, context: context) var errors = [ArchiveFrameError]() var reactionProtos = [BackupProto_Reaction]() @@ -81,64 +76,45 @@ class BackupArchiveReactionArchiver: BackupArchiveProtoStreamWriter { for reaction in reactions { let reactorAddress = context[reaction.authorRecipientId] - let insertResult: Result switch reactorAddress { case .localAddress: - insertResult = Result { - try reactionStore.createReaction( + reactionStore.createReaction( + uniqueMessageId: message.uniqueId, + emoji: reaction.emoji, + reactorAci: context.localIdentifiers.aci, + sentAtTimestamp: reaction.sentTimestamp, + sortOrder: reaction.sortOrder, + context: context, + ) + case .contact(let address): + if let aci = address.aci { + reactionStore.createReaction( uniqueMessageId: message.uniqueId, emoji: reaction.emoji, - reactorAci: context.localIdentifiers.aci, + reactorAci: aci, sentAtTimestamp: reaction.sentTimestamp, sortOrder: reaction.sortOrder, context: context, ) - } - case .contact(let address): - if let aci = address.aci { - insertResult = Result { - try reactionStore.createReaction( - uniqueMessageId: message.uniqueId, - emoji: reaction.emoji, - reactorAci: aci, - sentAtTimestamp: reaction.sentTimestamp, - sortOrder: reaction.sortOrder, - context: context, - ) - } } else if let e164 = address.e164 { - insertResult = Result { - try reactionStore.createLegacyReaction( - uniqueMessageId: message.uniqueId, - emoji: reaction.emoji, - reactorE164: e164, - sentAtTimestamp: reaction.sentTimestamp, - sortOrder: reaction.sortOrder, - context: context, - ) - } + reactionStore.createLegacyReaction( + uniqueMessageId: message.uniqueId, + emoji: reaction.emoji, + reactorE164: e164, + sentAtTimestamp: reaction.sentTimestamp, + sortOrder: reaction.sortOrder, + context: context, + ) } else { reactionErrors.append(.restoreFrameError(.invalidProtoData(.reactionNotFromAciOrE164))) - continue } case .group, .distributionList, .releaseNotesChannel, .callLink: // Referencing a group or distributionList as the author is invalid. reactionErrors.append(.restoreFrameError(.invalidProtoData(.reactionNotFromAciOrE164))) - continue case nil: reactionErrors.append(.restoreFrameError( .invalidProtoData(.recipientIdNotFound(reaction.authorRecipientId)), )) - continue - } - - switch insertResult { - case .success: - break - case .failure(let insertError): - reactionErrors.append( - .restoreFrameError(.databaseInsertionFailed(insertError)), - ) } } diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveReactionStore.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveReactionStore.swift index a88117ee67..ed3223f2bc 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveReactionStore.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveReactionStore.swift @@ -16,14 +16,16 @@ public class BackupArchiveReactionStore { func allReactions( message: TSMessage, context: BackupArchive.RecipientArchivingContext, - ) throws -> [OWSReaction] { + ) -> [OWSReaction] { let sql = """ SELECT * FROM \(OWSReaction.databaseTableName) WHERE \(OWSReaction.columnName(.uniqueMessageId)) = ? ORDER BY \(OWSReaction.columnName(.id)) DESC """ - let statement = try context.tx.database.cachedStatement(sql: sql) - return try OWSReaction.fetchAll(statement, arguments: [message.uniqueId]) + return failIfThrows { + let statement = try context.tx.database.cachedStatement(sql: sql) + return try OWSReaction.fetchAll(statement, arguments: [message.uniqueId]) + } } // MARK: - Restoring @@ -35,15 +37,18 @@ public class BackupArchiveReactionStore { sentAtTimestamp: UInt64, sortOrder: UInt64, context: BackupArchive.RecipientRestoringContext, - ) throws { - let reaction = OWSReaction.fromRestoredBackup( + ) { + let reaction = OWSReaction( uniqueMessageId: uniqueMessageId, emoji: emoji, reactorAci: reactorAci, + reactorPhoneNumber: nil, sentAtTimestamp: sentAtTimestamp, sortOrder: sortOrder, ) - try reaction.insert(context.tx.database) + failIfThrows { + try reaction.insert(context.tx.database) + } } /// In the olden days before the introduction of Acis, reactions were sent by e164s. @@ -54,14 +59,17 @@ public class BackupArchiveReactionStore { sentAtTimestamp: UInt64, sortOrder: UInt64, context: BackupArchive.RecipientRestoringContext, - ) throws { - let reaction = OWSReaction.fromRestoredBackup( + ) { + let reaction = OWSReaction( uniqueMessageId: uniqueMessageId, emoji: emoji, - reactorE164: reactorE164, + reactorAci: nil, + reactorPhoneNumber: reactorE164.stringValue, sentAtTimestamp: sentAtTimestamp, sortOrder: sortOrder, ) - try reaction.insert(context.tx.database) + failIfThrows { + try reaction.insert(context.tx.database) + } } } diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSIncomingMessageArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSIncomingMessageArchiver.swift index 81bea4719b..bf148fef11 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSIncomingMessageArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSIncomingMessageArchiver.swift @@ -104,6 +104,7 @@ extension BackupArchiveTSIncomingMessageArchiver: BackupArchive.TSMessageEditHis threadInfo: BackupArchive.ChatArchivingContext.CachedThreadInfo, context: BackupArchive.ChatArchivingContext, ) -> BackupArchive.ArchiveInteractionResult
{ + let incomingMessageRowId = incomingMessage.sqliteRowId! var partialErrors = [ArchiveFrameError]() guard @@ -183,14 +184,8 @@ extension BackupArchiveTSIncomingMessageArchiver: BackupArchive.TSMessageEditHis expireStartDate = nil } - guard let interactionRowId = incomingMessage.sqliteRowId else { - return .completeFailure(.fatalArchiveError( - .fetchedInteractionMissingRowId, - )) - } - let pinMessageDetails = pinnedMessageManager.pinMessageDetails( - interactionId: interactionRowId, + interactionId: incomingMessageRowId, tx: context.tx, ) diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSMessageContentsArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSMessageContentsArchiver.swift index 1ea2dc74e7..c8928275ed 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSMessageContentsArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSMessageContentsArchiver.swift @@ -191,11 +191,7 @@ class BackupArchiveTSMessageContentsArchiver: BackupArchiveProtoStreamWriter { _ message: TSMessage, context: BackupArchive.ChatArchivingContext, ) -> ArchiveInteractionResult { - guard let messageRowId = message.sqliteRowId else { - return .completeFailure(.fatalArchiveError( - .fetchedInteractionMissingRowId, - )) - } + let messageRowId = message.sqliteRowId! let messageOwnedReferencedAttachments: MessageOwnedReferencedAttachments = { let referencedAttachments = attachmentStore.fetchReferencedAttachmentsOwnedByMessage( diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSOutgoingMessageArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSOutgoingMessageArchiver.swift index a70c063b66..5c4b9bead5 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSOutgoingMessageArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupArchiveTSOutgoingMessageArchiver.swift @@ -144,6 +144,7 @@ extension BackupArchiveTSOutgoingMessageArchiver: BackupArchive.TSMessageEditHis threadInfo: BackupArchive.ChatArchivingContext.CachedThreadInfo, context: BackupArchive.ChatArchivingContext, ) -> BackupArchive.ArchiveInteractionResult
{ + let outgoingMessageRowId = outgoingMessage.sqliteRowId! var partialErrors = [ArchiveFrameError]() let wasAnySendSealedSender: Bool @@ -177,14 +178,8 @@ extension BackupArchiveTSOutgoingMessageArchiver: BackupArchive.TSMessageEditHis expireStartDate = nil } - guard let interactionRowId = outgoingMessage.sqliteRowId else { - return .completeFailure(.fatalArchiveError( - .fetchedInteractionMissingRowId, - )) - } - let pinMessageDetails = pinnedMessageManager.pinMessageDetails( - interactionId: interactionRowId, + interactionId: outgoingMessageRowId, tx: context.tx, ) diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupOversizeTextCache.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupOversizeTextCache.swift index 72651e57f3..d0f4c76f91 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupOversizeTextCache.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/BackupOversizeTextCache.swift @@ -186,15 +186,10 @@ class BackupArchiveInlinedOversizeTextArchiver { )) } - let oversizedText: String? - do { - oversizedText = try self.fetchInlineableOversizedText( - attachmentId: oversizeTextReferencedAttachment.attachment.id, - tx: context.tx, - ) - } catch { - return .completeFailure(.fatalArchiveError(.oversizedTextCacheFetchError(error))) - } + let oversizedText = self.fetchInlineableOversizedText( + attachmentId: oversizeTextReferencedAttachment.attachment.id, + tx: context.tx, + ) if let oversizedText { // If we had downloaded the attachment, we'd have an oversized text to inline. @@ -381,7 +376,7 @@ class BackupArchiveInlinedOversizeTextArchiver { // MARK: - Helpers - private func fetchInlineableOversizedText(attachmentId: Attachment.IDType, tx: DBReadTransaction) throws -> String? { + private func fetchInlineableOversizedText(attachmentId: Attachment.IDType, tx: DBReadTransaction) -> String? { return failIfThrows { try BackupOversizeTextCache .filter(Column(BackupOversizeTextCache.CodingKeys.attachmentRowId) == attachmentId) diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/ChatUpdateMessages/BackupArchiveSimpleChatUpdateArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/ChatUpdateMessages/BackupArchiveSimpleChatUpdateArchiver.swift index 1ccc489304..a02254c0ed 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/ChatUpdateMessages/BackupArchiveSimpleChatUpdateArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/ChatUpdateMessages/BackupArchiveSimpleChatUpdateArchiver.swift @@ -74,9 +74,9 @@ final class BackupArchiveSimpleChatUpdateArchiver { .typeEndPoll, .typePinnedMessage: // Non-simple chat update types - return .completeFailure(.fatalArchiveError( - .developerError(OWSAssertionError("Unexpected info message type: \(infoMessage.messageType)")), - )) + return .completeFailure(.fatalArchiveError(.developerError( + message: "Unexpected info message type: \(infoMessage.messageType)", + ))) case .verificationStateChange: guard let verificationStateChangeMessage = infoMessage as? OWSVerificationStateChangeMessage else { return messageFailure(.verificationStateChangeNotExpectedSDSRecordType) diff --git a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/ChatUpdateMessages/GroupUpdates/BackupArchiveGroupUpdateMessageArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/ChatUpdateMessages/GroupUpdates/BackupArchiveGroupUpdateMessageArchiver.swift index 692c7d4bbb..679a9e8433 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/ChatUpdateMessages/GroupUpdates/BackupArchiveGroupUpdateMessageArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/ChatItem/ChatUpdateMessages/GroupUpdates/BackupArchiveGroupUpdateMessageArchiver.swift @@ -37,7 +37,7 @@ final class BackupArchiveGroupUpdateMessageArchiver { case .nonGroupUpdate: // Should be impossible. return .completeFailure(.fatalArchiveError(.developerError( - OWSAssertionError("Invalid interaction type"), + message: "Invalid interaction type", ))) case .legacyRawString: return .skippableInteraction(.skippableGroupUpdate(.legacyRawString)) diff --git a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveCallLinkRecipientArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveCallLinkRecipientArchiver.swift index 1e14f6e400..33937eb8a4 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveCallLinkRecipientArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveCallLinkRecipientArchiver.swift @@ -46,75 +46,72 @@ public class BackupArchiveCallLinkRecipientArchiver: BackupArchiveProtoStreamWri context: BackupArchive.RecipientArchivingContext, ) throws(CancellationError) -> ArchiveMultiFrameResult { var errors = [ArchiveFrameError]() - do { - try context.bencher.wrapEnumeration( - callLinkStore.enumerateAll(tx:block:), - tx: context.tx, - ) { record, frameBencher in - try Task.checkCancellation() - autoreleasepool { - var callLink = BackupProto_CallLink() - callLink.rootKey = record.rootKey.bytes - if let adminPasskey = record.adminPasskey { - // If there is no adminPasskey on the record, then the - // local user is not the call admin, and we leave this - // field blank on the proto. - callLink.adminKey = adminPasskey - } - if let name = record.name { - // If the default name is being used, just leave the field blank. - callLink.name = name - } - callLink.restrictions = { () -> BackupProto_CallLink.Restrictions in - if let restrictions = record.restrictions { - switch restrictions { - case .none: return .none - case .adminApproval: return .adminApproval - case .unknown: return .unknown - } - } else { - return .unknown - } - }() - let callLinkRecordId = CallLinkRecordId(record) - let callLinkAppId: RecipientAppId = .callLink(callLinkRecordId) - // Lacking an expiration is a valid state. It can occur 1) if we hadn't - // yet fetched the expiration from the server at the time of backup, or - // 2) if someone deletes a call link before we're able to fetch the - // expiration. - BackupArchive.Timestamps.setTimestampIfValid( - from: record, - \.expirationMs, - on: &callLink, - \.expirationMs, - allowZero: true, - ) - - owsAssertDebug(record.revoked != true, "call links should be deleted, not revoked") - - let recipientId = context.assignRecipientId(to: callLinkAppId) - let maybeError: ArchiveFrameError? = Self.writeFrameToStream( - stream, - frameBencher: frameBencher, - ) { - var recipient = BackupProto_Recipient() - recipient.id = recipientId.value - recipient.destination = .callLink(callLink) - var frame = BackupProto_Frame() - frame.item = .recipient(recipient) - return frame - } - if let maybeError { - errors.append(maybeError) - } + try context.bencher.wrapEnumeration( + tx: context.tx, + enumerationBlock: { tx, block throws(CancellationError) in + try callLinkStore.enumerateAll(tx: tx, block: block) + }, + perEnumerantBlock: { record, frameBencher -> Bool in + var callLink = BackupProto_CallLink() + callLink.rootKey = record.rootKey.bytes + if let adminPasskey = record.adminPasskey { + // If there is no adminPasskey on the record, then the + // local user is not the call admin, and we leave this + // field blank on the proto. + callLink.adminKey = adminPasskey } - } - } catch let error as CancellationError { - throw error - } catch { - return .completeFailure(.fatalArchiveError(.callLinkRecordIteratorError(error))) - } + if let name = record.name { + // If the default name is being used, just leave the field blank. + callLink.name = name + } + callLink.restrictions = { () -> BackupProto_CallLink.Restrictions in + if let restrictions = record.restrictions { + switch restrictions { + case .none: return .none + case .adminApproval: return .adminApproval + case .unknown: return .unknown + } + } else { + return .unknown + } + }() + + let callLinkRecordId = CallLinkRecordId(record) + let callLinkAppId: RecipientAppId = .callLink(callLinkRecordId) + // Lacking an expiration is a valid state. It can occur 1) if we hadn't + // yet fetched the expiration from the server at the time of backup, or + // 2) if someone deletes a call link before we're able to fetch the + // expiration. + BackupArchive.Timestamps.setTimestampIfValid( + from: record, + \.expirationMs, + on: &callLink, + \.expirationMs, + allowZero: true, + ) + + owsAssertDebug(record.revoked != true, "call links should be deleted, not revoked") + + let recipientId = context.assignRecipientId(to: callLinkAppId) + let maybeError: ArchiveFrameError? = Self.writeFrameToStream( + stream, + frameBencher: frameBencher, + ) { + var recipient = BackupProto_Recipient() + recipient.id = recipientId.value + recipient.destination = .callLink(callLink) + var frame = BackupProto_Frame() + frame.item = .recipient(recipient) + return frame + } + if let maybeError { + errors.append(maybeError) + } + + return true + }, + ) if errors.isEmpty { return .success @@ -167,23 +164,19 @@ public class BackupArchiveCallLinkRecipientArchiver: BackupArchiveProtoStreamWri || callLinkProto.expirationMs != 0, ) - do { - let record = try callLinkStore.insertFromBackup( - rootKey: rootKey, - adminPasskey: adminKey, - name: hasAnyState ? callLinkProto.name.nilIfEmpty : nil, - restrictions: hasAnyState ? restrictions : nil, - revoked: hasAnyState ? false : nil, - expiration: hasAnyState ? Int64(callLinkProto.expirationMs / 1000) : nil, - isUpcoming: hasAnyState ? (adminKey != nil) : nil, - tx: context.tx, - ) - let callLinkRecordId = CallLinkRecordId(record) - context[recipient.recipientId] = .callLink(callLinkRecordId) - context[callLinkRecordId] = record - } catch { - return .failure([.restoreFrameError(.databaseInsertionFailed(error))]) - } + let record = callLinkStore.insertFromBackup( + rootKey: rootKey, + adminPasskey: adminKey, + name: hasAnyState ? callLinkProto.name.nilIfEmpty : nil, + restrictions: hasAnyState ? restrictions : nil, + revoked: hasAnyState ? false : nil, + expiration: hasAnyState ? Int64(callLinkProto.expirationMs / 1000) : nil, + isUpcoming: hasAnyState ? (adminKey != nil) : nil, + tx: context.tx, + ) + let callLinkRecordId = CallLinkRecordId(record) + context[recipient.recipientId] = .callLink(callLinkRecordId) + context[callLinkRecordId] = record return .success } diff --git a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveContactRecipientArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveContactRecipientArchiver.swift index 9946bc6c0a..572903553e 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveContactRecipientArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveContactRecipientArchiver.swift @@ -122,7 +122,12 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit /// key" for contacts. They directly contain many of the fields we store /// in a `Contact` recipient, with the other fields keyed off data in /// the recipient. - let recipientBlock: (SignalRecipient, BackupArchive.Bencher.FrameBencher) -> Void = { recipient, frameBencher in + try context.bencher.wrapEnumeration( + tx: context.tx, + enumerationBlock: { tx, block throws(CancellationError) in + try recipientStore.enumerateAllSignalRecipients(tx: tx, block: block) + }, + ) { recipient, frameBencher -> Bool in guard let contactAddress = BackupArchive.ContactAddress( aci: recipient.aci, @@ -132,7 +137,7 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit else { /// Skip recipients with no identifiers, but don't add to the /// list of errors. - return + return true } guard @@ -143,7 +148,7 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit ) else { // Skip the local user. - return + return true } /// Track the `ServiceId`s for this `SignalRecipient`, so we don't @@ -177,7 +182,7 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit .fetchOne(context.tx.database) } catch let error { errors.append(.archiveFrameError(.unableToFetchRecipientIdentity(error))) - return + return true } let username: String? = recipient.aci @@ -265,21 +270,7 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit ) writeToStream(contact: contact, contactAddress: contactAddress, contactDbRowId: recipient.id, frameBencher: frameBencher) - } - - do { - try context.bencher.wrapEnumeration( - recipientStore.enumerateAllSignalRecipients(tx:block:), - tx: context.tx, - ) { recipient, frameBencher in - autoreleasepool { - recipientBlock(recipient, frameBencher) - } - } - } catch let error as CancellationError { - throw error - } catch { - return .completeFailure(.fatalArchiveError(.recipientIteratorError(error))) + return true } /// After enumerating all `SignalRecipient`s, we enumerate @@ -302,18 +293,19 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit /// like `OWSUserProfile`. If, in the future, we have an enforced 1:1 /// relationship between `SignalRecipient` and `OWSUserProfile`, we can /// remove this code. - context.bencher.wrapEnumeration( - profileManager.enumerateUserProfiles(tx:block:), + try context.bencher.wrapEnumeration( tx: context.tx, - ) { userProfile, frameBencher in - autoreleasepool { + enumerationBlock: { tx, block throws(CancellationError) in + try profileManager.enumerateUserProfiles(tx: tx, block: block) + }, + perEnumerantBlock: { userProfile, frameBencher -> Bool in if let serviceId = userProfile.serviceId { let (inserted, _) = archivedServiceIds.insert(serviceId) if !inserted { /// Bail early if we've already archived a `Contact` for this /// service ID. - return + return true } } if let phoneNumber = userProfile.phoneNumber { @@ -322,7 +314,7 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit if !inserted { /// Bail early if we've already archived a `Contact` for this /// phone number. - return + return true } } @@ -335,7 +327,7 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit else { /// Skip profiles with no identifiers, but don't add to the /// list of errors. - return + return true } let signalServiceAddress: BackupArchive.InteropAddress @@ -344,7 +336,7 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit /// Skip the local user. We need to check `internalAddress` /// here, since the "local user profile" has historically been /// persisted with a special, magic phone number. - return + return true case .otherUser(let _signalServiceAddress): signalServiceAddress = _signalServiceAddress } @@ -385,8 +377,10 @@ public class BackupArchiveContactRecipientArchiver: BackupArchiveProtoStreamWrit contactDbRowId: nil, frameBencher: frameBencher, ) - } - } + + return true + }, + ) if errors.isEmpty { return .success diff --git a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveDistributionListRecipientArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveDistributionListRecipientArchiver.swift index 16e7c423d3..0d70faa9d5 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveDistributionListRecipientArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveDistributionListRecipientArchiver.swift @@ -33,46 +33,41 @@ public class BackupArchiveDistributionListRecipientArchiver: BackupArchiveProtoS ) throws(CancellationError) -> ArchiveMultiFrameResult { var errors = [ArchiveFrameError]() - do { - // enumerate deleted threads - for item in privateStoryThreadDeletionManager.allDeletedIdentifiers(tx: context.tx) { - try Task.checkCancellation() - autoreleasepool { - context.bencher.processFrame { frameBencher in - self.archiveDeletedStoryList( - rawDistributionId: item, - stream: stream, - frameBencher: frameBencher, - context: context, - errors: &errors, - ) - } - } + for item in privateStoryThreadDeletionManager.allDeletedIdentifiers(tx: context.tx) { + if Task.isCancelled { + throw CancellationError() } - try context.bencher.wrapEnumeration( - threadStore.enumerateStoryThreads(tx:block:), - tx: context.tx, - ) { storyThread, frameBencher in - try Task.checkCancellation() - autoreleasepool { - self.archiveStoryThread( - storyThread, + + autoreleasepool { + context.bencher.processFrame { frameBencher in + archiveDeletedStoryList( + rawDistributionId: item, stream: stream, frameBencher: frameBencher, context: context, errors: &errors, ) } - - return true } - } catch let error as CancellationError { - throw error - } catch { - // The enumeration of threads failed, not the processing of one single thread. - return .completeFailure(.fatalArchiveError(.threadIteratorError(error))) } + try context.bencher.wrapEnumeration( + tx: context.tx, + enumerationBlock: { tx, block throws(CancellationError) in + try threadStore.enumerateStoryThreads(tx: tx, block: block) + }, + perEnumerantBlock: { [self] storyThread, frameBencher in + archiveStoryThread( + storyThread, + stream: stream, + frameBencher: frameBencher, + context: context, + errors: &errors, + ) + return true + }, + ) + if errors.isEmpty { return .success } else { diff --git a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveGroupRecipientArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveGroupRecipientArchiver.swift index 757e85eb16..7f932d17c8 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveGroupRecipientArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveGroupRecipientArchiver.swift @@ -62,12 +62,12 @@ public class BackupArchiveGroupRecipientArchiver: BackupArchiveProtoStreamWriter do { try context.bencher.wrapEnumeration( - threadStore.enumerateGroupThreads(tx:block:), tx: context.tx, - ) { groupThread, frameBencher in - try Task.checkCancellation() - autoreleasepool { - self.archiveGroupThread( + enumerationBlock: { tx, block throws(CancellationError) in + try threadStore.enumerateGroupThreads(tx: tx, block: block) + }, + perEnumerantBlock: { [self] groupThread, frameBencher -> Bool in + archiveGroupThread( groupThread, blockedGroupIds: blockedGroupIds, stream: stream, @@ -75,15 +75,10 @@ public class BackupArchiveGroupRecipientArchiver: BackupArchiveProtoStreamWriter context: context, errors: &errors, ) - } - return true - } - } catch let error as CancellationError { - throw error - } catch { - // The enumeration of threads failed, not the processing of one single thread. - return .completeFailure(.fatalArchiveError(.threadIteratorError(error))) + return true + }, + ) } if errors.isEmpty { diff --git a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveRecipientStore.swift b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveRecipientStore.swift index d97f770cf1..2d14533945 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveRecipientStore.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/Recipient/BackupArchiveRecipientStore.swift @@ -3,6 +3,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // +import GRDB + public class BackupArchiveRecipientStore { private let recipientTable: RecipientDatabaseTable @@ -20,13 +22,13 @@ public class BackupArchiveRecipientStore { func enumerateAllSignalRecipients( tx: DBReadTransaction, - block: (SignalRecipient) -> Void, - ) throws { - let cursor = try SignalRecipient.fetchCursor(tx.database) - while let next = try cursor.next() { - try Task.checkCancellation() - block(next) + block: (SignalRecipient) throws(CancellationError) -> Bool, + ) throws(CancellationError) { + var cursor = FailIfThrowsRecordCursor { + try SignalRecipient.fetchCursor(tx.database) } + + while let recipient = cursor.next(), try block(recipient) {} } func fetchRecipient( diff --git a/SignalServiceKit/Backups/Archiving/Archivers/StickerPack/BackupArchiveStickerPackArchiver.swift b/SignalServiceKit/Backups/Archiving/Archivers/StickerPack/BackupArchiveStickerPackArchiver.swift index 9436937639..0ebf04e717 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/StickerPack/BackupArchiveStickerPackArchiver.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/StickerPack/BackupArchiveStickerPackArchiver.swift @@ -47,15 +47,25 @@ public class BackupArchiveStickerPackArchiver: BackupArchiveProtoStreamWriter { context: BackupArchive.ArchivingContext, ) throws(CancellationError) -> ArchiveMultiFrameResult { var errors = [ArchiveFrameError]() - var handledPacks = Set() - func archiveInstalledStickerPack( - _ installedStickerPack: StickerPackRecord, - _ frameBencher: BackupArchive.Bencher.FrameBencher, - ) { - autoreleasepool { - guard !handledPacks.contains(installedStickerPack.packId) else { return } + // Iterate over installed sticker packs... + try context.bencher.wrapEnumeration( + tx: context.tx, + enumerationBlock: { tx, block throws(CancellationError) in + var cursor = FailIfThrowsRecordCursor { + try StickerPackRecord + .filter(Column(StickerPackRecord.CodingKeys.isInstalled) == true) + .fetchCursor(tx.database) + } + + while let stickerPack = cursor.next(), try block(stickerPack) {} + }, + perEnumerantBlock: { installedStickerPack, frameBencher -> Bool in + if handledPacks.contains(installedStickerPack.packId) { + return true + } + let maybeError: ArchiveFrameError? = Self.writeFrameToStream( stream, frameBencher: frameBencher, @@ -75,67 +85,45 @@ public class BackupArchiveStickerPackArchiver: BackupArchiveProtoStreamWriter { } else { handledPacks.insert(installedStickerPack.packId) } - } - } - func enumerateStickerPackRecord(tx: DBReadTransaction, block: (StickerPackRecord) throws -> Void) throws { - let cursor = try StickerPackRecord - .filter(Column(StickerPackRecord.CodingKeys.isInstalled) == true) - .fetchCursor(tx.database) - while let stickerPack = try cursor.next() { - try block(stickerPack) - } - } - - // Iterate over the installed sticker packs - do { - try context.bencher.wrapEnumeration( - enumerateStickerPackRecord(tx:block:), - tx: context.tx, - ) { stickerPack, frameBencher in - try Task.checkCancellation() - archiveInstalledStickerPack(stickerPack, frameBencher) - } - } catch let error as CancellationError { - throw error - } catch { - return .completeFailure(.fatalArchiveError(.stickerPackIteratorError(error))) - } + return true + }, + ) // Iterate over any restored sticker packs that have yet to be downloaded via StickerManager. - do { - try context.bencher.wrapEnumeration( - backupStickerPackDownloadStore.iterateAllEnqueued(tx:block:), - tx: context.tx, - ) { record, frameBencher in - try Task.checkCancellation() - autoreleasepool { - guard !handledPacks.contains(record.packId) else { return } - let maybeError: ArchiveFrameError? = Self.writeFrameToStream( - stream, - frameBencher: frameBencher, - ) { - var stickerPack = BackupProto_StickerPack() - stickerPack.packID = record.packId - stickerPack.packKey = record.packKey - - var frame = BackupProto_Frame() - frame.item = .stickerPack(stickerPack) - - return frame - } - if let maybeError { - errors.append(maybeError) - } else { - handledPacks.insert(record.packId) - } + try context.bencher.wrapEnumeration( + tx: context.tx, + enumerationBlock: { tx, block throws(CancellationError) in + try backupStickerPackDownloadStore.iterateAllEnqueued(tx: tx, block: block) + }, + perEnumerantBlock: { record, frameBencher -> Bool in + if handledPacks.contains(record.packId) { + return true } - } - } catch let error as CancellationError { - throw error - } catch { - return .completeFailure(.fatalArchiveError(.stickerPackIteratorError(error))) - } + + let maybeError: ArchiveFrameError? = Self.writeFrameToStream( + stream, + frameBencher: frameBencher, + ) { + var stickerPack = BackupProto_StickerPack() + stickerPack.packID = record.packId + stickerPack.packKey = record.packKey + + var frame = BackupProto_Frame() + frame.item = .stickerPack(stickerPack) + + return frame + } + + if let maybeError { + errors.append(maybeError) + } else { + handledPacks.insert(record.packId) + } + + return true + }, + ) if errors.count > 0 { return .partialSuccess(errors) @@ -155,15 +143,11 @@ public class BackupArchiveStickerPackArchiver: BackupArchiveProtoStreamWriter { _ stickerPack: BackupProto_StickerPack, context: BackupArchive.RestoringContext, ) -> RestoreFrameResult { - do { - try backupStickerPackDownloadStore.enqueue( - packId: stickerPack.packID, - packKey: stickerPack.packKey, - tx: context.tx, - ) - } catch { - return .failure([.restoreFrameError(.databaseInsertionFailed(error))]) - } + backupStickerPackDownloadStore.enqueue( + packId: stickerPack.packID, + packKey: stickerPack.packKey, + tx: context.tx, + ) return .success } } diff --git a/SignalServiceKit/Backups/Archiving/Archivers/StickerPack/BackupStickerPackDownloadStore.swift b/SignalServiceKit/Backups/Archiving/Archivers/StickerPack/BackupStickerPackDownloadStore.swift index 46ef10f517..96c66ff375 100644 --- a/SignalServiceKit/Backups/Archiving/Archivers/StickerPack/BackupStickerPackDownloadStore.swift +++ b/SignalServiceKit/Backups/Archiving/Archivers/StickerPack/BackupStickerPackDownloadStore.swift @@ -9,90 +9,68 @@ import GRDB /// a backup, but whose full data has not been downloaded. /// Post-restore, items listed here will be asynchronously passed to /// StickerManager, downloaded, and persisted as usable StickerPack objects. -public protocol BackupStickerPackDownloadStore { +public struct BackupStickerPackDownloadStore { + + private typealias Record = QueuedBackupStickerPackDownload /// "Enqueue" a sticker pack from a backup for download. /// Doesn't actually trigger a download; this is delegated to the TaskQueueLoader /// in StickerManager - func enqueue( - packId: Data, - packKey: Data, - tx: DBWriteTransaction, - ) throws + public func enqueue(packId: Data, packKey: Data, tx: DBWriteTransaction) { + failIfThrows { + // If this record is already in the queue, don't insert a second copy + if + let _ = try QueuedAttachmentDownloadRecord + .filter(Column(Record.CodingKeys.packId) == packId) + .fetchOne(tx.database) + { + return + } + + var record = Record(packId: packId, packKey: packKey) + try record.insert(tx.database) + } + } /// Read rows off the queue one by one, calling the block for each. - func iterateAllEnqueued( - tx: DBReadTransaction, - block: ( - QueuedBackupStickerPackDownload, - ) throws -> Void, - ) throws - - /// Return the top `count` rows of the download queue. - func peek( - count: UInt, - tx: DBReadTransaction, - ) throws -> [QueuedBackupStickerPackDownload] - - /// Remove the record from the download queue. - func removeRecordFromQueue( - record: QueuedBackupStickerPackDownload, - tx: DBWriteTransaction, - ) throws -} - -public class BackupStickerPackDownloadStoreImpl: BackupStickerPackDownloadStore { - - public typealias Record = QueuedBackupStickerPackDownload - - public func enqueue(packId: Data, packKey: Data, tx: DBWriteTransaction) throws { - let db = tx.database - var record = Record(packId: packId, packKey: packKey) - - // If this record is already in the queue, don't insert a second copy - if - let _ = try QueuedAttachmentDownloadRecord - .filter(Column(Record.CodingKeys.packId) == packId) - .fetchOne(db) - { - return - } - - try record.insert(db) - } - + /// - Parameter block + /// A block executed for each enumerated record. Returns `true` if + /// enumeration should continue, and `false` otherwise. public func iterateAllEnqueued( tx: DBReadTransaction, - block: (QueuedBackupStickerPackDownload) throws -> Void, - ) throws { - let db = tx.database - let cursor = try Record - .order([Column(Record.CodingKeys.id).desc]) - .fetchCursor(db) - - while let record = try cursor.next() { - try block(record) + block: (QueuedBackupStickerPackDownload) throws(CancellationError) -> Bool, + ) throws(CancellationError) { + var cursor = FailIfThrowsRecordCursor { + try Record + .order([Column(Record.CodingKeys.id).desc]) + .fetchCursor(tx.database) } + + while let record = cursor.next(), try block(record) {} } + /// Return the top `count` rows of the download queue. public func peek( count: UInt, tx: DBReadTransaction, - ) throws -> [QueuedBackupStickerPackDownload] { - let db = tx.database - return try Record - .order([Column(Record.CodingKeys.id).asc]) - .limit(Int(count)) - .fetchAll(db) + ) -> [QueuedBackupStickerPackDownload] { + return failIfThrows { + try Record + .order([Column(Record.CodingKeys.id).asc]) + .limit(Int(count)) + .fetchAll(tx.database) + } } + /// Remove the record from the download queue. public func removeRecordFromQueue( record: QueuedBackupStickerPackDownload, tx: DBWriteTransaction, - ) throws { - let db = tx.database - try Record - .filter(Column(Record.CodingKeys.id) == record.id) - .deleteAll(db) + ) { + failIfThrows { + try Record + .filter(Column(Record.CodingKeys.id) == record.id) + .deleteAll(tx.database) + } } } diff --git a/SignalServiceKit/Backups/Archiving/BackupArchive+Bench.swift b/SignalServiceKit/Backups/Archiving/BackupArchive+Bench.swift index d851b7b083..9c4c1073b0 100644 --- a/SignalServiceKit/Backups/Archiving/BackupArchive+Bench.swift +++ b/SignalServiceKit/Backups/Archiving/BackupArchive+Bench.swift @@ -10,15 +10,26 @@ extension BackupArchive { /// A `Bencher` specialized for measuring Backup archiving. class ArchiveBencher: Bencher { - /// Given a block that does an enumeration over db objects, wraps that enumeration to instead take - /// a closure with a FrameBencher that also measures the time spent enumerating. - func wrapEnumeration( - _ enumerationFunc: (DBReadTransaction, (EnumeratedInput) throws -> Output) throws -> Void, + /// Wrap the given enumeration method to facilitate measurement of the + /// time spent. + /// + /// - Parameter enumerationBlock + /// A block that enumerates models and calls the block it is passed for + /// each model. + /// - Parameter perEnumerantBlock + /// A block called once per enumerated model. Returns `true` if + /// enumeration should continue; `false` otherwise. + func wrapEnumeration( tx: DBReadTransaction, - enumerationBlock: @escaping (EnumeratedInput, FrameBencher) throws -> Output, - ) rethrows { + enumerationBlock: (DBReadTransaction, (Enumerant) throws(CancellationError) -> Bool) throws(CancellationError) -> Void, + perEnumerantBlock: @escaping (Enumerant, FrameBencher) -> Bool, + ) throws(CancellationError) { var enumerationStepStartDate = dateProvider() - try enumerationFunc(tx) { enumeratedInput throws in + try enumerationBlock(tx) { enumeratedInput throws(CancellationError) -> Bool in + if Task.isCancelled { + throw CancellationError() + } + defer { // A little cheating - the "end" of this step is the "start" // of the next one. @@ -31,32 +42,9 @@ extension BackupArchive { enumerationStepStartDate: enumerationStepStartDate, ) - return try enumerationBlock(enumeratedInput, frameBencher) - } - } - - /// Variant of the above where the block doesn't throw; unfortunately `rethrows` - /// can't cover two layers of throws variations. - func wrapEnumeration( - _ enumerationFunc: (DBReadTransaction, (EnumeratedInput) -> Output) throws -> Void, - tx: DBReadTransaction, - enumerationBlock: @escaping (EnumeratedInput, FrameBencher) -> Output, - ) rethrows { - var enumerationStepStartDate = dateProvider() - try enumerationFunc(tx) { enumeratedInput in - defer { - // A little cheating - the "end" of this step is the "start" - // of the next one. - enumerationStepStartDate = dateProvider() + return autoreleasepool { + perEnumerantBlock(enumeratedInput, frameBencher) } - - let frameBencher = FrameBencher( - bencher: self, - dateProvider: dateProvider, - enumerationStepStartDate: enumerationStepStartDate, - ) - - return enumerationBlock(enumeratedInput, frameBencher) } } } diff --git a/SignalServiceKit/Backups/Archiving/BackupArchive+Shims.swift b/SignalServiceKit/Backups/Archiving/BackupArchive+Shims.swift index ff01613637..c7d1129ae3 100644 --- a/SignalServiceKit/Backups/Archiving/BackupArchive+Shims.swift +++ b/SignalServiceKit/Backups/Archiving/BackupArchive+Shims.swift @@ -255,7 +255,10 @@ public class _MessageBackup_PreferencesWrapper: _MessageBackup_PreferencesShim { public protocol _MessageBackup_ProfileManagerShim { - func enumerateUserProfiles(tx: DBReadTransaction, block: (OWSUserProfile) -> Void) + func enumerateUserProfiles( + tx: DBReadTransaction, + block: (OWSUserProfile) throws(CancellationError) -> Bool, + ) throws(CancellationError) -> Void func getUserProfile(for address: SignalServiceAddress, tx: DBReadTransaction) -> OWSUserProfile? @@ -295,10 +298,15 @@ public class _MessageBackup_ProfileManagerWrapper: _MessageBackup_ProfileManager self.profileManager = profileManager } - public func enumerateUserProfiles(tx: DBReadTransaction, block: (OWSUserProfile) -> Void) { - OWSUserProfile.anyEnumerate(transaction: tx) { profile, _ in - block(profile) + public func enumerateUserProfiles( + tx: DBReadTransaction, + block: (OWSUserProfile) throws(CancellationError) -> Bool, + ) throws(CancellationError) { + var cursor = FailIfThrowsRecordCursor { + try OWSUserProfile.fetchCursor(tx.database) } + + while let profile = cursor.next(), try block(profile) {} } public func getUserProfile(for address: SignalServiceAddress, tx: DBReadTransaction) -> OWSUserProfile? { diff --git a/SignalServiceKit/Calls/CallLinkRecordStore.swift b/SignalServiceKit/Calls/CallLinkRecordStore.swift index 7f10dccbc9..0c4e945002 100644 --- a/SignalServiceKit/Calls/CallLinkRecordStore.swift +++ b/SignalServiceKit/Calls/CallLinkRecordStore.swift @@ -33,17 +33,19 @@ public struct CallLinkRecordStore { expiration: Int64?, isUpcoming: Bool?, tx: DBWriteTransaction, - ) throws -> CallLinkRecord { - return try CallLinkRecord.insertFromBackup( - rootKey: rootKey, - adminPasskey: adminPasskey, - name: name, - restrictions: restrictions, - revoked: revoked, - expiration: expiration, - isUpcoming: isUpcoming, - tx: tx, - ) + ) -> CallLinkRecord { + return failIfThrows { + try CallLinkRecord.insertFromBackup( + rootKey: rootKey, + adminPasskey: adminPasskey, + name: name, + restrictions: restrictions, + revoked: revoked, + expiration: expiration, + isUpcoming: isUpcoming, + tx: tx, + ) + } } public func fetchOrInsert(rootKey: CallLinkRootKey, tx: DBWriteTransaction) -> (record: CallLinkRecord, inserted: Bool) { @@ -89,15 +91,19 @@ public struct CallLinkRecordStore { } } - public func enumerateAll(tx: DBReadTransaction, block: (CallLinkRecord) throws -> Void) throws { - do { - let cursor = try CallLinkRecord.fetchCursor(tx.database) - while let next = try cursor.next() { - try block(next) - } - } catch { - throw error.grdbErrorForLogging + /// Enumerate all `CallLinkRecord`s. + /// - Parameter block + /// A block executed for each enumerated record. Returns `true` if + /// enumeration should continue, and `false` otherwise. + public func enumerateAll( + tx: DBReadTransaction, + block: (CallLinkRecord) throws(CancellationError) -> Bool, + ) throws(CancellationError) { + var cursor = FailIfThrowsRecordCursor { + try CallLinkRecord.fetchCursor(tx.database) } + + while let record = cursor.next(), try block(record) {} } public func fetchUpcoming(earlierThan expirationTimestamp: Date?, limit: Int, tx: DBReadTransaction) -> [CallLinkRecord] { diff --git a/SignalServiceKit/Calls/CallRecord/CallRecordStore.swift b/SignalServiceKit/Calls/CallRecord/CallRecordStore.swift index 0a9a1616a8..b36c212951 100644 --- a/SignalServiceKit/Calls/CallRecord/CallRecordStore.swift +++ b/SignalServiceKit/Calls/CallRecord/CallRecordStore.swift @@ -126,10 +126,13 @@ public protocol CallRecordStore { ) /// Enumerate all ad hoc call records. + /// - Parameter block + /// A block executed for each enumerated record. Returns `true` if + /// enumeration should continue, and `false` otherwise. func enumerateAdHocCallRecords( tx: DBReadTransaction, - block: (CallRecord) throws -> Void, - ) throws + block: (CallRecord) throws(CancellationError) -> Bool, + ) throws(CancellationError) /// Fetch the record for the given call ID in the given thread, if one /// exists. @@ -279,6 +282,19 @@ class CallRecordStoreImpl: CallRecordStore { } } + func enumerateAdHocCallRecords( + tx: DBReadTransaction, + block: (CallRecord) throws(CancellationError) -> Bool, + ) throws(CancellationError) { + var cursor = FailIfThrowsRecordCursor { + return try CallRecord + .filter(Column(CallRecord.CodingKeys.callType) == CallRecord.CallType.adHocCall.rawValue) + .fetchCursor(tx.database) + } + + while let record = cursor.next(), try block(record) {} + } + func fetch( callId: UInt64, conversationId: CallRecord.ConversationID, @@ -408,22 +424,6 @@ class CallRecordStoreImpl: CallRecordStore { } } - func enumerateAdHocCallRecords( - tx: DBReadTransaction, - block: (CallRecord) throws -> Void, - ) throws { - do { - let cursor = try CallRecord - .filter(Column(CallRecord.CodingKeys.callType) == CallRecord.CallType.adHocCall.rawValue) - .fetchCursor(tx.database) - while let value = try cursor.next() { - try block(value) - } - } catch { - throw error.grdbErrorForLogging - } - } - fileprivate func compileQuery( columnArgs: [(CallRecord.CodingKeys, DatabaseValueConvertible)], limit: Int? = nil, diff --git a/SignalServiceKit/Environment/AppSetup.swift b/SignalServiceKit/Environment/AppSetup.swift index ab1bb58332..b28ee77c9f 100644 --- a/SignalServiceKit/Environment/AppSetup.swift +++ b/SignalServiceKit/Environment/AppSetup.swift @@ -1387,7 +1387,7 @@ extension AppSetup.GlobalsContinuation { recipientTable: recipientDatabaseTable, searchableNameIndexer: searchableNameIndexer, ) - let backupStickerPackDownloadStore = BackupStickerPackDownloadStoreImpl() + let backupStickerPackDownloadStore = BackupStickerPackDownloadStore() let backupStoryStore = BackupArchiveStoryStore( storyStore: storyStore, storyRecipientStore: storyRecipientStore, diff --git a/SignalServiceKit/Messages/Interactions/TSMessage.swift b/SignalServiceKit/Messages/Interactions/TSMessage.swift index 8126573c18..82afa5bab6 100644 --- a/SignalServiceKit/Messages/Interactions/TSMessage.swift +++ b/SignalServiceKit/Messages/Interactions/TSMessage.swift @@ -198,9 +198,10 @@ public extension TSMessage { let newReaction = OWSReaction( uniqueMessageId: uniqueId, emoji: emoji, - reactor: reactor, + reactorAci: reactor, + reactorPhoneNumber: nil, sentAtTimestamp: sentAtTimestamp, - receivedAtTimestamp: receivedAtTimestamp, + sortOrder: receivedAtTimestamp, ) newReaction.anyInsert(transaction: tx) diff --git a/SignalServiceKit/Messages/Reactions/OWSReaction.swift b/SignalServiceKit/Messages/Reactions/OWSReaction.swift index 98d888cc51..acd7ad88bf 100644 --- a/SignalServiceKit/Messages/Reactions/OWSReaction.swift +++ b/SignalServiceKit/Messages/Reactions/OWSReaction.swift @@ -55,27 +55,7 @@ public final class OWSReaction: NSObject, SDSCodableModel, Decodable, NSSecureCo SignalServiceAddress.legacyAddress(serviceId: reactorAci, phoneNumber: reactorPhoneNumber) } - /// Note that we initialize with a receivedAtTimestamp, but should make no assumptions - /// that the sortOrder is always a timestamp at read time. Backups use sortOrders that - /// may not be timestamps. - public convenience init( - uniqueMessageId: String, - emoji: String, - reactor: Aci, - sentAtTimestamp: UInt64, - receivedAtTimestamp: UInt64, - ) { - self.init( - uniqueMessageId: uniqueMessageId, - emoji: emoji, - reactorAci: reactor, - reactorPhoneNumber: nil, - sentAtTimestamp: sentAtTimestamp, - sortOrder: receivedAtTimestamp, - ) - } - - private init( + init( uniqueMessageId: String, emoji: String, reactorAci: Aci?, @@ -93,40 +73,6 @@ public final class OWSReaction: NSObject, SDSCodableModel, Decodable, NSSecureCo self.read = false } - public static func fromRestoredBackup( - uniqueMessageId: String, - emoji: String, - reactorAci: Aci, - sentAtTimestamp: UInt64, - sortOrder: UInt64, - ) -> Self { - return Self( - uniqueMessageId: uniqueMessageId, - emoji: emoji, - reactorAci: reactorAci, - reactorPhoneNumber: nil, - sentAtTimestamp: sentAtTimestamp, - sortOrder: sortOrder, - ) - } - - public static func fromRestoredBackup( - uniqueMessageId: String, - emoji: String, - reactorE164: E164, - sentAtTimestamp: UInt64, - sortOrder: UInt64, - ) -> OWSReaction { - return .init( - uniqueMessageId: uniqueMessageId, - emoji: emoji, - reactorAci: nil, - reactorPhoneNumber: reactorE164.stringValue, - sentAtTimestamp: sentAtTimestamp, - sortOrder: sortOrder, - ) - } - public func markAsRead(transaction: DBWriteTransaction) { anyUpdate(transaction: transaction) { reaction in reaction.read = true diff --git a/SignalServiceKit/Messages/Stickers/StickerManager.swift b/SignalServiceKit/Messages/Stickers/StickerManager.swift index 3e9efe050f..ffe6103f2b 100644 --- a/SignalServiceKit/Messages/Stickers/StickerManager.swift +++ b/SignalServiceKit/Messages/Stickers/StickerManager.swift @@ -91,7 +91,7 @@ public class StickerManager: NSObject { db: DependenciesBridge.shared.db, runner: StickerPackDownloadTaskRunner( store: StickerPackDownloadTaskRecordStore( - store: BackupStickerPackDownloadStoreImpl(), + store: BackupStickerPackDownloadStore(), ), ), ) @@ -1195,14 +1195,14 @@ public class StickerManager: NSObject { self.store = store } - func peek(count: UInt, tx: DBReadTransaction) throws -> [StickerPackDownloadTaskRecord] { - return try store.peek(count: count, tx: tx).map { + func peek(count: UInt, tx: DBReadTransaction) -> [StickerPackDownloadTaskRecord] { + return store.peek(count: count, tx: tx).map { return .init(id: $0.id!, record: $0) } } - func removeRecord(_ record: StickerPackDownloadTaskRecord, tx: DBWriteTransaction) throws { - try store.removeRecordFromQueue(record: record.record, tx: tx) + func removeRecord(_ record: StickerPackDownloadTaskRecord, tx: DBWriteTransaction) { + store.removeRecordFromQueue(record: record.record, tx: tx) } } diff --git a/SignalServiceKit/Mocks/CallRecord/MockCallRecordStore.swift b/SignalServiceKit/Mocks/CallRecord/MockCallRecordStore.swift index ce3fc0a14a..38cd60a773 100644 --- a/SignalServiceKit/Mocks/CallRecord/MockCallRecordStore.swift +++ b/SignalServiceKit/Mocks/CallRecord/MockCallRecordStore.swift @@ -48,10 +48,11 @@ class MockCallRecordStore: CallRecordStore { }) } - func enumerateAdHocCallRecords(tx: DBReadTransaction, block: (CallRecord) throws -> Void) throws { - try callRecords.forEach { record in - guard record.callType == .adHocCall else { return } - try block(record) + func enumerateAdHocCallRecords(tx: DBReadTransaction, block: (CallRecord) throws(CancellationError) -> Bool) throws(CancellationError) { + for record in callRecords.filter({ $0.callType == .adHocCall }) { + guard try block(record) else { + return + } } } diff --git a/SignalServiceKit/Storage/Database/FailIfThrowsRecordCursor.swift b/SignalServiceKit/Storage/Database/FailIfThrowsRecordCursor.swift new file mode 100644 index 0000000000..60d9e7fdbe --- /dev/null +++ b/SignalServiceKit/Storage/Database/FailIfThrowsRecordCursor.swift @@ -0,0 +1,22 @@ +// +// Copyright 2026 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only +// + +import GRDB + +/// A convenience wrapper for `GRDB.RecordCursor` that swallows errors using +/// `failIfThrows` and adds `Sequence` conformance. +struct FailIfThrowsRecordCursor: IteratorProtocol, Sequence { + typealias Element = T + + private let recordCursor: RecordCursor + + init(makeCursorBlock: () throws -> RecordCursor) { + self.recordCursor = failIfThrows(block: makeCursorBlock) + } + + mutating func next() -> T? { + return failIfThrows(block: recordCursor.next) + } +} diff --git a/SignalServiceKit/Storage/Database/Records/ThreadFinder.swift b/SignalServiceKit/Storage/Database/Records/ThreadFinder.swift index 8e40dbb000..774d1c22ea 100644 --- a/SignalServiceKit/Storage/Database/Records/ThreadFinder.swift +++ b/SignalServiceKit/Storage/Database/Records/ThreadFinder.swift @@ -54,76 +54,74 @@ public class ThreadFinder { /// - Parameter block /// A block executed for each enumerated thread. Returns `true` if /// enumeration should continue, and `false` otherwise. - public func enumerateStoryThreads( - transaction: DBReadTransaction, - block: (TSPrivateStoryThread) throws -> Bool, - ) throws { + public func enumerateStoryThreads( + tx: DBReadTransaction, + block: (TSPrivateStoryThread) throws(E) -> Bool, + ) throws(E) { let sql = """ SELECT * FROM \(TSThread.databaseTableName) WHERE \(threadColumn: .recordType) = \(SDSRecordType.privateStoryThread.rawValue) """ - let cursor = try TSPrivateStoryThread.fetchCursor( - transaction.database, - sql: sql, - ) - while let storyThread = try cursor.next() { - guard try block(storyThread) else { - break - } + + var cursor = FailIfThrowsRecordCursor { + try TSPrivateStoryThread.fetchCursor( + tx.database, + sql: sql, + ) } + + while let storyThread = cursor.next(), try block(storyThread) {} } /// Enumerates group threads in "last interaction" order. /// - Parameter block /// A block executed for each enumerated thread. Returns `true` if /// enumeration should continue, and `false` otherwise. - public func enumerateGroupThreads( - transaction: DBReadTransaction, - block: (TSGroupThread) throws -> Bool, - ) throws { + public func enumerateGroupThreads( + tx: DBReadTransaction, + block: (TSGroupThread) throws(E) -> Bool, + ) throws(E) { let sql = """ SELECT * FROM \(TSThread.databaseTableName) - WHERE \(groupThreadColumn: .groupModel) IS NOT NULL + WHERE \(threadColumn: .recordType) = \(SDSRecordType.groupThread.rawValue) ORDER BY \(threadColumn: .lastInteractionRowId) DESC """ - let cursor = try TSThread.fetchCursor( - transaction.database, - sql: sql, - ) - while let threadRecord = try cursor.next() { - guard let groupThread = threadRecord as? TSGroupThread else { - owsFailDebug("Skipping thread that's not a group.") - continue - } - guard try block(groupThread) else { - break - } + var cursor = FailIfThrowsRecordCursor { + return try TSGroupThread.fetchCursor( + tx.database, + sql: sql, + ) } + + while let groupThread = cursor.next(), try block(groupThread) {} } /// Enumerates all non-story threads in arbitrary order. /// - Parameter block /// A block executed for each enumerated thread. Returns `true` if /// enumeration should continue, and `false` otherwise. - public func enumerateNonStoryThreads( - transaction: DBReadTransaction, - block: (TSThread) throws -> Bool, - ) throws { + public func enumerateNonStoryThreads( + tx: DBReadTransaction, + block: (TSThread) throws(E) -> Bool, + ) throws(E) { let sql = """ SELECT * FROM \(TSThread.databaseTableName) WHERE \(threadColumn: .recordType) IS NOT ? """ - let cursor = try TSThread.fetchCursor( - transaction.database, - sql: sql, - arguments: [SDSRecordType.privateStoryThread.rawValue], - ) - while let thread = try cursor.next(), try block(thread) {} + var cursor = FailIfThrowsRecordCursor { + return try TSThread.fetchCursor( + tx.database, + sql: sql, + arguments: [SDSRecordType.privateStoryThread.rawValue], + ) + } + + while let thread = cursor.next(), try block(thread) {} } public func visibleThreadCount( diff --git a/SignalServiceKit/StorageService/StorageServiceUnknownFieldMigrator.swift b/SignalServiceKit/StorageService/StorageServiceUnknownFieldMigrator.swift index 54f4a32203..aeac1f9721 100644 --- a/SignalServiceKit/StorageService/StorageServiceUnknownFieldMigrator.swift +++ b/SignalServiceKit/StorageService/StorageServiceUnknownFieldMigrator.swift @@ -173,7 +173,7 @@ public class StorageServiceUnknownFieldMigrator { recordMap[groupId] = $0.dontNotifyForMentionsIfMuted } } - try? ThreadFinder().enumerateGroupThreads(transaction: tx) { groupThread -> Bool in + ThreadFinder().enumerateGroupThreads(tx: tx) { groupThread -> Bool in let remoteValue: TSThreadMentionNotificationMode = (recordMap[groupThread.groupId] ?? false) ? .never : .always if isPrimaryDevice { diff --git a/SignalServiceKit/Threads/ThreadStore.swift b/SignalServiceKit/Threads/ThreadStore.swift index 8353bfd688..2bc7b941e8 100644 --- a/SignalServiceKit/Threads/ThreadStore.swift +++ b/SignalServiceKit/Threads/ThreadStore.swift @@ -6,30 +6,30 @@ public import LibSignalClient public protocol ThreadStore { - /// Covers contact and group threads. + /// Enumerate all threads other than `TSPrivateStoryThread`. /// - Parameter block /// A block executed for each enumerated thread. Returns `true` if /// enumeration should continue, and `false` otherwise. - func enumerateNonStoryThreads( + func enumerateNonStoryThreads( tx: DBReadTransaction, - block: (TSThread) throws -> Bool, - ) throws + block: (TSThread) throws(E) -> Bool, + ) throws(E) /// Enumerates story distribution lists /// - Parameter block /// A block executed for each enumerated thread. Returns `true` if /// enumeration should continue, and `false` otherwise. - func enumerateStoryThreads( + func enumerateStoryThreads( tx: DBReadTransaction, - block: (TSPrivateStoryThread) throws -> Bool, - ) throws + block: (TSPrivateStoryThread) throws(E) -> Bool, + ) throws(E) /// Enumerates group threads in "last interaction" order. /// - Parameter block /// A block executed for each enumerated thread. Returns `true` if /// enumeration should continue, and `false` otherwise. - func enumerateGroupThreads( + func enumerateGroupThreads( tx: DBReadTransaction, - block: (TSGroupThread) throws -> Bool, - ) throws + block: (TSGroupThread) throws(E) -> Bool, + ) throws(E) func fetchThread(rowId: Int64, tx: DBReadTransaction) -> TSThread? func fetchThread(uniqueId: String, tx: DBReadTransaction) -> TSThread? func fetchContactThreads(serviceId: ServiceId, tx: DBReadTransaction) -> [TSContactThread] @@ -153,16 +153,16 @@ public class ThreadStoreImpl: ThreadStore { public init() {} - public func enumerateNonStoryThreads(tx: DBReadTransaction, block: (TSThread) throws -> Bool) throws { - return try ThreadFinder().enumerateNonStoryThreads(transaction: tx, block: block) + public func enumerateNonStoryThreads(tx: DBReadTransaction, block: (TSThread) throws(E) -> Bool) throws(E) { + return try ThreadFinder().enumerateNonStoryThreads(tx: tx, block: block) } - public func enumerateStoryThreads(tx: DBReadTransaction, block: (TSPrivateStoryThread) throws -> Bool) throws { - return try ThreadFinder().enumerateStoryThreads(transaction: tx, block: block) + public func enumerateStoryThreads(tx: DBReadTransaction, block: (TSPrivateStoryThread) throws(E) -> Bool) throws(E) { + return try ThreadFinder().enumerateStoryThreads(tx: tx, block: block) } - public func enumerateGroupThreads(tx: DBReadTransaction, block: (TSGroupThread) throws -> Bool) throws { - return try ThreadFinder().enumerateGroupThreads(transaction: tx, block: block) + public func enumerateGroupThreads(tx: DBReadTransaction, block: (TSGroupThread) throws(E) -> Bool) throws(E) { + return try ThreadFinder().enumerateGroupThreads(tx: tx, block: block) } public func fetchThread(rowId: Int64, tx: DBReadTransaction) -> TSThread? { @@ -289,7 +289,7 @@ public class MockThreadStore: ThreadStore { private(set) var threads = [TSThread]() public var nextRowId: Int64 = 1 - public func enumerateNonStoryThreads(tx: DBReadTransaction, block: (TSThread) throws -> Bool) throws { + public func enumerateNonStoryThreads(tx: DBReadTransaction, block: (TSThread) throws(E) -> Bool) throws(E) { for thread in threads { guard !(thread is TSPrivateStoryThread) else { continue @@ -300,7 +300,7 @@ public class MockThreadStore: ThreadStore { } } - public func enumerateStoryThreads(tx: DBReadTransaction, block: (TSPrivateStoryThread) throws -> Bool) throws { + public func enumerateStoryThreads(tx: DBReadTransaction, block: (TSPrivateStoryThread) throws(E) -> Bool) throws(E) { for thread in threads { guard let storyThread = thread as? TSPrivateStoryThread else { continue @@ -311,7 +311,7 @@ public class MockThreadStore: ThreadStore { } } - public func enumerateGroupThreads(tx: DBReadTransaction, block: (TSGroupThread) throws -> Bool) throws { + public func enumerateGroupThreads(tx: DBReadTransaction, block: (TSGroupThread) throws(E) -> Bool) throws(E) { for thread in threads { guard let groupThread = thread as? TSGroupThread else { continue diff --git a/SignalServiceKit/tests/Storage/Database/DatabaseRecoveryTest.swift b/SignalServiceKit/tests/Storage/Database/DatabaseRecoveryTest.swift index eabf374369..17ea2c3023 100644 --- a/SignalServiceKit/tests/Storage/Database/DatabaseRecoveryTest.swift +++ b/SignalServiceKit/tests/Storage/Database/DatabaseRecoveryTest.swift @@ -152,9 +152,10 @@ final class DatabaseRecoveryTest: SSKBaseTest { let reaction = OWSReaction( uniqueMessageId: message.uniqueId, emoji: "💽", - reactor: localAci, + reactorAci: localAci, + reactorPhoneNumber: nil, sentAtTimestamp: 1234, - receivedAtTimestamp: 1234, + sortOrder: 1234, ) reaction.anyInsert(transaction: transaction)