diff --git a/SignalServiceKit/Messages/MessageRequestPendingReceipts.swift b/SignalServiceKit/Messages/MessageRequestPendingReceipts.swift index e98962745a..c7e938b91d 100644 --- a/SignalServiceKit/Messages/MessageRequestPendingReceipts.swift +++ b/SignalServiceKit/Messages/MessageRequestPendingReceipts.swift @@ -24,73 +24,63 @@ public class MessageRequestPendingReceipts: PendingReceiptRecorder { // MARK: - - let finder = PendingReceiptFinder() + private let finder = PendingReceiptFinder() // MARK: - public func recordPendingReadReceipt(for message: TSIncomingMessage, thread: TSThread, transaction: DBWriteTransaction) { - do { - try finder.recordPendingReadReceipt(for: message, thread: thread, transaction: transaction) - } catch { - owsFailDebug("error: \(error)") + guard let threadId = thread.sqliteRowId else { + owsFail("can't record pending receipt without thread id") } + finder.recordPendingReadReceipt(for: message, threadId: threadId, transaction: transaction) } public func recordPendingViewedReceipt(for message: TSIncomingMessage, thread: TSThread, transaction: DBWriteTransaction) { - do { - try finder.recordPendingViewedReceipt(for: message, thread: thread, transaction: transaction) - } catch { - owsFailDebug("error: \(error)") + guard let threadId = thread.sqliteRowId else { + owsFail("can't record pending receipt without thread id") } + finder.recordPendingViewedReceipt(for: message, threadId: threadId, transaction: transaction) } // MARK: - @objc private func profileWhitelistDidChange(notification: Notification) { - do { - try SSKEnvironment.shared.databaseStorageRef.grdbStorage.read { transaction in - guard let thread = notification.affectedThread(transaction: transaction) else { - return - } - let userProfileWriter = notification.userProfileWriter - if userProfileWriter == .localUser { - try self.sendAnyReadyReceipts(threads: [thread], transaction: transaction) - } else { - try self.removeAnyReadyReceipts(threads: [thread], transaction: transaction) - } + SSKEnvironment.shared.databaseStorageRef.read { transaction in + guard let threadId = notification.affectedThread(transaction: transaction)?.sqliteRowId else { + return + } + let userProfileWriter = notification.userProfileWriter + if userProfileWriter == .localUser { + self.sendAnyReadyReceipts(threadIds: [threadId], transaction: transaction) + } else { + self.removeAnyReadyReceipts(threadIds: [threadId], transaction: transaction) } - } catch { - owsFailDebug("error: \(error)") } } private func auditPendingReceipts() { - do { - try SSKEnvironment.shared.databaseStorageRef.grdbStorage.read { transaction in - let threads = try self.finder.threadsWithPendingReceipts(transaction: transaction) - try self.sendAnyReadyReceipts(threads: threads, transaction: transaction) - } - } catch { - owsFailDebug("error: \(error)") + SSKEnvironment.shared.databaseStorageRef.read { transaction in + let threadIds = self.finder.threadIdsWithPendingReceipts(transaction: transaction) + self.sendAnyReadyReceipts(threadIds: threadIds, transaction: transaction) } } - private func sendAnyReadyReceipts(threads: [TSThread], transaction: DBReadTransaction) throws { - let pendingReadReceipts: [PendingReadReceiptRecord] = try threads.flatMap { thread -> [PendingReadReceiptRecord] in + private func sendAnyReadyReceipts(threadIds: some Sequence, transaction: DBReadTransaction) { + var pendingReadReceipts = [PendingReadReceiptRecord]() + var pendingViewedReceipts = [PendingViewedReceiptRecord]() + + for threadId in threadIds { + guard let thread = ThreadFinder().fetch(rowId: threadId, tx: transaction) else { + // The thread may be missing because there's no foreign key relationship. + continue + } guard !thread.hasPendingMessageRequest(transaction: transaction) else { - return [] + continue } - return try self.finder.pendingReadReceipts(thread: thread, transaction: transaction) - } - - let pendingViewedReceipts: [PendingViewedReceiptRecord] = try threads.flatMap { thread -> [PendingViewedReceiptRecord] in - guard !thread.hasPendingMessageRequest(transaction: transaction) else { - return [] - } - - return try self.finder.pendingViewedReceipts(thread: thread, transaction: transaction) + pendingReadReceipts.append(contentsOf: self.finder.pendingReadReceipts(threadId: threadId, transaction: transaction)) + pendingViewedReceipts.append(contentsOf: self.finder.pendingViewedReceipts(threadId: threadId, transaction: transaction)) } guard !pendingReadReceipts.isEmpty || !pendingViewedReceipts.isEmpty else { @@ -106,21 +96,21 @@ public class MessageRequestPendingReceipts: PendingReceiptRecorder { } } - private func removeAnyReadyReceipts(threads: [TSThread], transaction: DBReadTransaction) throws { - let pendingReadReceipts: [PendingReadReceiptRecord] = try threads.flatMap { thread -> [PendingReadReceiptRecord] in + private func removeAnyReadyReceipts(threadIds: some Sequence, transaction: DBReadTransaction) { + var pendingReadReceipts = [PendingReadReceiptRecord]() + var pendingViewedReceipts = [PendingViewedReceiptRecord]() + + for threadId in threadIds { + guard let thread = ThreadFinder().fetch(rowId: threadId, tx: transaction) else { + // The thread may be missing because there's no foreign key relationship. + continue + } guard !thread.hasPendingMessageRequest(transaction: transaction) else { - return [] + continue } - return try self.finder.pendingReadReceipts(thread: thread, transaction: transaction) - } - - let pendingViewedReceipts: [PendingViewedReceiptRecord] = try threads.flatMap { thread -> [PendingViewedReceiptRecord] in - guard !thread.hasPendingMessageRequest(transaction: transaction) else { - return [] - } - - return try self.finder.pendingViewedReceipts(thread: thread, transaction: transaction) + pendingReadReceipts.append(contentsOf: self.finder.pendingReadReceipts(threadId: threadId, transaction: transaction)) + pendingViewedReceipts.append(contentsOf: self.finder.pendingViewedReceipts(threadId: threadId, transaction: transaction)) } guard !pendingReadReceipts.isEmpty || !pendingViewedReceipts.isEmpty else { @@ -128,19 +118,15 @@ public class MessageRequestPendingReceipts: PendingReceiptRecorder { } SSKEnvironment.shared.databaseStorageRef.asyncWrite { transaction in - do { - try self.finder.delete(pendingReadReceipts: pendingReadReceipts, transaction: transaction) - try self.finder.delete(pendingViewedReceipts: pendingViewedReceipts, transaction: transaction) - } catch { - owsFailDebug("error: \(error)") - } + self.finder.delete(pendingReadReceipts: pendingReadReceipts, transaction: transaction) + self.finder.delete(pendingViewedReceipts: pendingViewedReceipts, transaction: transaction) } } private func enqueue(pendingReadReceipts: [PendingReadReceiptRecord], pendingViewedReceipts: [PendingViewedReceiptRecord], transaction: DBWriteTransaction) throws { guard OWSReceiptManager.areReadReceiptsEnabled(transaction: transaction) else { Logger.info("Deleting all pending receipts - user has subsequently disabled read receipts.") - try finder.deleteAllPendingReceipts(transaction: transaction) + finder.deleteAllPendingReceipts(transaction: transaction) return } @@ -156,7 +142,7 @@ public class MessageRequestPendingReceipts: PendingReceiptRecorder { tx: transaction, ) } - try finder.delete(pendingReadReceipts: pendingReadReceipts, transaction: transaction) + finder.delete(pendingReadReceipts: pendingReadReceipts, transaction: transaction) for receipt in pendingViewedReceipts { guard let authorAci = self.authorAci(aciString: receipt.authorAciString, phoneNumber: receipt.authorPhoneNumber, tx: transaction) else { @@ -170,7 +156,7 @@ public class MessageRequestPendingReceipts: PendingReceiptRecorder { tx: transaction, ) } - try finder.delete(pendingViewedReceipts: pendingViewedReceipts, transaction: transaction) + finder.delete(pendingViewedReceipts: pendingViewedReceipts, transaction: transaction) } private func authorAci(aciString: String?, phoneNumber: String?, tx: DBReadTransaction) -> Aci? { @@ -187,13 +173,9 @@ public class MessageRequestPendingReceipts: PendingReceiptRecorder { // MARK: - Persistence -public class PendingReceiptFinder { - public func recordPendingReadReceipt(for message: TSIncomingMessage, thread: TSThread, transaction: DBWriteTransaction) throws { - guard let threadId = thread.sqliteRowId else { - throw OWSAssertionError("threadId was unexpectedly nil") - } - - let record = PendingReadReceiptRecord( +private class PendingReceiptFinder { + func recordPendingReadReceipt(for message: TSIncomingMessage, threadId: TSThread.RowId, transaction: DBWriteTransaction) { + var record = PendingReadReceiptRecord( threadId: threadId, messageTimestamp: Int64(message.timestamp), messageUniqueId: message.uniqueId, @@ -201,15 +183,13 @@ public class PendingReceiptFinder { authorAci: Aci.parseFrom(aciString: message.authorUUID), ) - try record.insert(transaction.database) + failIfThrows { + try record.insert(transaction.database) + } } - public func recordPendingViewedReceipt(for message: TSIncomingMessage, thread: TSThread, transaction: DBWriteTransaction) throws { - guard let threadId = thread.sqliteRowId else { - throw OWSAssertionError("threadId was unexpectedly nil") - } - - let record = PendingViewedReceiptRecord( + func recordPendingViewedReceipt(for message: TSIncomingMessage, threadId: TSThread.RowId, transaction: DBWriteTransaction) { + var record = PendingViewedReceiptRecord( threadId: threadId, messageTimestamp: Int64(message.timestamp), messageUniqueId: message.uniqueId, @@ -217,62 +197,64 @@ public class PendingReceiptFinder { authorAci: Aci.parseFrom(aciString: message.authorUUID), ) - try record.insert(transaction.database) - } - - public func pendingReadReceipts(thread: TSThread, transaction: DBReadTransaction) throws -> [PendingReadReceiptRecord] { - guard let threadId = thread.sqliteRowId else { - throw OWSAssertionError("threadId was unexpectedly nil") + failIfThrows { + try record.insert(transaction.database) } - - let sql = """ - SELECT * FROM pending_read_receipts - WHERE threadId = \(threadId) - """ - return try PendingReadReceiptRecord.fetchAll(transaction.database, sql: sql) } - public func pendingViewedReceipts(thread: TSThread, transaction: DBReadTransaction) throws -> [PendingViewedReceiptRecord] { - guard let threadId = thread.sqliteRowId else { - throw OWSAssertionError("threadId was unexpectedly nil") + func pendingReadReceipts(threadId: TSThread.RowId, transaction: DBReadTransaction) -> [PendingReadReceiptRecord] { + let sql = """ + SELECT * FROM \(PendingReadReceiptRecord.databaseTableName) WHERE threadId = ? + """ + return failIfThrows { + return try PendingReadReceiptRecord.fetchAll(transaction.database, sql: sql, arguments: [threadId]) } - - let sql = """ - SELECT * FROM pending_viewed_receipts - WHERE threadId = \(threadId) - """ - return try PendingViewedReceiptRecord.fetchAll(transaction.database, sql: sql) } - public func threadsWithPendingReceipts(transaction: DBReadTransaction) throws -> [TSThread] { + func pendingViewedReceipts(threadId: TSThread.RowId, transaction: DBReadTransaction) -> [PendingViewedReceiptRecord] { + let sql = """ + SELECT * FROM \(PendingViewedReceiptRecord.databaseTableName) WHERE threadId = ? + """ + return failIfThrows { + return try PendingViewedReceiptRecord.fetchAll(transaction.database, sql: sql, arguments: [threadId]) + } + } + + func threadIdsWithPendingReceipts(transaction: DBReadTransaction) -> Set { let readSql = """ - SELECT DISTINCT model_TSThread.* FROM model_TSThread - INNER JOIN pending_read_receipts - ON pending_read_receipts.threadId = model_TSThread.id + SELECT DISTINCT threadId FROM \(PendingReadReceiptRecord.databaseTableName) """ - let readThreads = try TSThread.grdbFetchCursor(sql: readSql, transaction: transaction).all() + let readThreadIds = failIfThrows { + return try Int64.fetchAll(transaction.database, sql: readSql) + } let viewedSql = """ - SELECT DISTINCT model_TSThread.* FROM model_TSThread - INNER JOIN pending_viewed_receipts - ON pending_viewed_receipts.threadId = model_TSThread.id + SELECT DISTINCT threadId FROM \(PendingViewedReceiptRecord.databaseTableName) """ - let viewedThreads = try TSThread.grdbFetchCursor(sql: viewedSql, transaction: transaction).all() + let viewedThreadIds = failIfThrows { + return try Int64.fetchAll(transaction.database, sql: viewedSql) + } - return Array(Set(readThreads + viewedThreads)) + return Set(readThreadIds + viewedThreadIds) } - public func delete(pendingReadReceipts: [PendingReadReceiptRecord], transaction: DBWriteTransaction) throws { - try PendingReadReceiptRecord.deleteAll(transaction.database, keys: pendingReadReceipts.compactMap { $0.id }) + func delete(pendingReadReceipts: [PendingReadReceiptRecord], transaction: DBWriteTransaction) { + failIfThrows { + try PendingReadReceiptRecord.deleteAll(transaction.database, keys: pendingReadReceipts.compactMap { $0.id }) + } } - public func delete(pendingViewedReceipts: [PendingViewedReceiptRecord], transaction: DBWriteTransaction) throws { - try PendingViewedReceiptRecord.deleteAll(transaction.database, keys: pendingViewedReceipts.compactMap { $0.id }) + func delete(pendingViewedReceipts: [PendingViewedReceiptRecord], transaction: DBWriteTransaction) { + failIfThrows { + try PendingViewedReceiptRecord.deleteAll(transaction.database, keys: pendingViewedReceipts.compactMap { $0.id }) + } } - public func deleteAllPendingReceipts(transaction: DBWriteTransaction) throws { - try PendingReadReceiptRecord.deleteAll(transaction.database) - try PendingViewedReceiptRecord.deleteAll(transaction.database) + func deleteAllPendingReceipts(transaction: DBWriteTransaction) { + failIfThrows { + try PendingReadReceiptRecord.deleteAll(transaction.database) + try PendingViewedReceiptRecord.deleteAll(transaction.database) + } } } diff --git a/SignalServiceKit/Storage/PendingReadReceiptRecord.swift b/SignalServiceKit/Storage/PendingReadReceiptRecord.swift index e32a7b9e9a..2292f85d1b 100644 --- a/SignalServiceKit/Storage/PendingReadReceiptRecord.swift +++ b/SignalServiceKit/Storage/PendingReadReceiptRecord.swift @@ -7,7 +7,7 @@ import Foundation public import GRDB public import LibSignalClient -public struct PendingReadReceiptRecord: Codable, FetchableRecord, PersistableRecord { +public struct PendingReadReceiptRecord: Codable, FetchableRecord, MutablePersistableRecord { public static let databaseTableName = "pending_read_receipts" public private(set) var id: Int64? diff --git a/SignalServiceKit/Storage/PendingViewedReceiptRecord.swift b/SignalServiceKit/Storage/PendingViewedReceiptRecord.swift index 959aa8b56c..7e8c4663aa 100644 --- a/SignalServiceKit/Storage/PendingViewedReceiptRecord.swift +++ b/SignalServiceKit/Storage/PendingViewedReceiptRecord.swift @@ -7,7 +7,7 @@ import Foundation public import GRDB public import LibSignalClient -public struct PendingViewedReceiptRecord: Codable, FetchableRecord, PersistableRecord { +public struct PendingViewedReceiptRecord: Codable, FetchableRecord, MutablePersistableRecord { public static let databaseTableName = "pending_viewed_receipts" public private(set) var id: Int64? diff --git a/SignalServiceKit/tests/Storage/Database/DatabaseRecoveryTest.swift b/SignalServiceKit/tests/Storage/Database/DatabaseRecoveryTest.swift index 18e333c241..f4aa8fb55b 100644 --- a/SignalServiceKit/tests/Storage/Database/DatabaseRecoveryTest.swift +++ b/SignalServiceKit/tests/Storage/Database/DatabaseRecoveryTest.swift @@ -158,7 +158,7 @@ final class DatabaseRecoveryTest: SSKBaseTest { reaction.anyInsert(transaction: transaction) // Pending read receipts (not copied) - let pendingReadReceipt = PendingReadReceiptRecord( + var pendingReadReceipt = PendingReadReceiptRecord( threadId: contactThreadId, messageTimestamp: Int64(message.timestamp), messageUniqueId: message.uniqueId,