diff --git a/SignalServiceKit/Messages/DeliveryReceiptContext.swift b/SignalServiceKit/Messages/DeliveryReceiptContext.swift index 53a5645f06..e74c14d42b 100644 --- a/SignalServiceKit/Messages/DeliveryReceiptContext.swift +++ b/SignalServiceKit/Messages/DeliveryReceiptContext.swift @@ -5,14 +5,12 @@ import Foundation -public protocol DeliveryReceiptContext: AnyObject { +public protocol DeliveryReceiptContext { func addUpdate( message: TSOutgoingMessage, transaction: DBWriteTransaction, update: @escaping (TSOutgoingMessage) -> Void, ) - - func messages(_ timestamps: UInt64, transaction: DBReadTransaction) -> [TSOutgoingMessage] } private struct Update { @@ -20,20 +18,6 @@ private struct Update { let update: (TSOutgoingMessage) -> Void } -private extension TSOutgoingMessage { - static func fetch(_ timestamp: UInt64, transaction: DBReadTransaction) -> [TSOutgoingMessage] { - do { - return try InteractionFinder.fetchInteractions( - timestamp: timestamp, - transaction: transaction, - ).compactMap { $0 as? TSOutgoingMessage } - } catch { - owsFailDebug("Error loading interactions: \(error)") - return [] - } - } -} - public class PassthroughDeliveryReceiptContext: DeliveryReceiptContext { public init() {} @@ -42,25 +26,13 @@ public class PassthroughDeliveryReceiptContext: DeliveryReceiptContext { transaction: DBWriteTransaction, update: @escaping (TSOutgoingMessage) -> Void, ) { - let deferredUpdate = Update(message: message, update: update) - message.anyUpdateOutgoingMessage(transaction: transaction) { message in - deferredUpdate.update(message) - } - } - - public func messages(_ timestamp: UInt64, transaction: DBReadTransaction) -> [TSOutgoingMessage] { - return TSOutgoingMessage.fetch(timestamp, transaction: transaction) + message.anyUpdateOutgoingMessage(transaction: transaction, block: update) } } public class BatchingDeliveryReceiptContext: DeliveryReceiptContext { - private var messages = [UInt64: [TSOutgoingMessage]]() private var deferredUpdates: [Update] = [] -#if TESTABLE_BUILD - static var didRunDeferredUpdates: ((Int, DBWriteTransaction) -> Void)? -#endif - static func withDeferredUpdates(transaction: DBWriteTransaction, _ closure: (DeliveryReceiptContext) -> Void) { let instance = BatchingDeliveryReceiptContext() closure(instance) @@ -77,15 +49,6 @@ public class BatchingDeliveryReceiptContext: DeliveryReceiptContext { deferredUpdates.append(Update(message: message, update: update)) } - public func messages(_ timestamp: UInt64, transaction: DBReadTransaction) -> [TSOutgoingMessage] { - if let result = messages[timestamp] { - return result - } - let fetched = TSOutgoingMessage.fetch(timestamp, transaction: transaction) - messages[timestamp] = fetched - return fetched - } - private struct UpdateCollection { private var message: TSOutgoingMessage? private var closures = [(TSOutgoingMessage) -> Void]() @@ -118,20 +81,12 @@ public class BatchingDeliveryReceiptContext: DeliveryReceiptContext { } private func runDeferredUpdates(transaction: DBWriteTransaction) { + let deferredUpdates = self.deferredUpdates + self.deferredUpdates = [] var updateCollection = UpdateCollection() -#if TESTABLE_BUILD - let count = deferredUpdates.count -#endif - while let update = deferredUpdates.first { - deferredUpdates.removeFirst() + for update in deferredUpdates { updateCollection.addOrExecute(update: update, transaction: transaction) } updateCollection.execute(transaction: transaction) -#if TESTABLE_BUILD - let closure = Self.didRunDeferredUpdates - Self.didRunDeferredUpdates = nil - closure?(count, transaction) -#endif } - } diff --git a/SignalServiceKit/Messages/Interactions/TSOutgoingMessage.swift b/SignalServiceKit/Messages/Interactions/TSOutgoingMessage.swift index 7c730ef1ee..0db1a78fe9 100644 --- a/SignalServiceKit/Messages/Interactions/TSOutgoingMessage.swift +++ b/SignalServiceKit/Messages/Interactions/TSOutgoingMessage.swift @@ -657,6 +657,7 @@ extension TSOutgoingMessage { receiptType: .delivered, receiptTimestamp: timestamp, tryToClearPhoneNumberSharing: true, + context: context, tx: tx, ) } @@ -672,6 +673,7 @@ extension TSOutgoingMessage { deviceId: deviceId, receiptType: .read, receiptTimestamp: timestamp, + context: PassthroughDeliveryReceiptContext(), tx: tx, ) } @@ -687,6 +689,7 @@ extension TSOutgoingMessage { deviceId: deviceId, receiptType: .viewed, receiptTimestamp: timestamp, + context: PassthroughDeliveryReceiptContext(), tx: tx, ) } @@ -711,6 +714,7 @@ extension TSOutgoingMessage { receiptType: IncomingReceiptType, receiptTimestamp: UInt64, tryToClearPhoneNumberSharing: Bool = false, + context: any DeliveryReceiptContext, tx: DBWriteTransaction, ) { owsAssertDebug(recipientAddress.isValid) @@ -734,33 +738,37 @@ extension TSOutgoingMessage { recipientDatabaseTable: DependenciesBridge.shared.recipientDatabaseTable, signalServiceAddressCache: SSKEnvironment.shared.signalServiceAddressCacheRef, ) - anyUpdateOutgoingMessage(transaction: tx) { message in - guard - let recipientState: TSOutgoingMessageRecipientState = { - if let existingMatch = message.recipientAddressStates?[recipientAddress] { - return existingMatch - } - if let normalizedAddress = recipientStateMerger.normalizedAddressIfNeeded(for: recipientAddress, tx: tx) { - // If we get a receipt from a PNI, then normalizing PNIs -> ACIs won't fix - // it, but normalizing the address from a PNI to an ACI might fix it. - return message.recipientAddressStates?[normalizedAddress] - } else { - // If we get a receipt from an ACI, then we might have the PNI stored, and - // we need to migrate it to the ACI before we'll be able to find it. - recipientStateMerger.normalize(&message.recipientAddressStates, tx: tx) - return message.recipientAddressStates?[recipientAddress] - } - }() - else { - owsFailDebug("Missing recipient state for \(recipientAddress)") - return - } + context.addUpdate( + message: self, + transaction: tx, + update: { message in + guard + let recipientState: TSOutgoingMessageRecipientState = { + if let existingMatch = message.recipientAddressStates?[recipientAddress] { + return existingMatch + } + if let normalizedAddress = recipientStateMerger.normalizedAddressIfNeeded(for: recipientAddress, tx: tx) { + // If we get a receipt from a PNI, then normalizing PNIs -> ACIs won't fix + // it, but normalizing the address from a PNI to an ACI might fix it. + return message.recipientAddressStates?[normalizedAddress] + } else { + // If we get a receipt from an ACI, then we might have the PNI stored, and + // we need to migrate it to the ACI before we'll be able to find it. + recipientStateMerger.normalize(&message.recipientAddressStates, tx: tx) + return message.recipientAddressStates?[recipientAddress] + } + }() + else { + owsFailDebug("Missing recipient state for \(recipientAddress)") + return + } - recipientState.updateStatusIfPossible( - receiptType.asRecipientStatus, - statusTimestamp: receiptTimestamp, - ) - } + recipientState.updateStatusIfPossible( + receiptType.asRecipientStatus, + statusTimestamp: receiptTimestamp, + ) + }, + ) } } diff --git a/SignalServiceKit/tests/Messages/DeliveryReceiptContextTests.swift b/SignalServiceKit/tests/Messages/DeliveryReceiptContextTests.swift index b46f28db94..8256fb6cda 100644 --- a/SignalServiceKit/tests/Messages/DeliveryReceiptContextTests.swift +++ b/SignalServiceKit/tests/Messages/DeliveryReceiptContextTests.swift @@ -9,18 +9,15 @@ import XCTest class DeliveryReceiptContextTests: SSKBaseTest { func testExecutesDifferentMessages() throws { let aliceRecipient = SignalServiceAddress(phoneNumber: "+12345678900") - var timestamp: UInt64? - write { transaction in + let message = write { transaction in let aliceContactThread = TSContactThread.getOrCreateThread(withContactAddress: aliceRecipient, transaction: transaction) let helloAlice = TSOutgoingMessage(in: aliceContactThread, messageBody: "Hello Alice") helloAlice.anyInsert(transaction: transaction) - timestamp = helloAlice.timestamp + return helloAlice } - XCTAssertNotNil(timestamp) write { transaction in var messages = [TSOutgoingMessage]() BatchingDeliveryReceiptContext.withDeferredUpdates(transaction: transaction) { context in - let message = context.messages(timestamp!, transaction: transaction)[0] context.addUpdate(message: message, transaction: transaction) { m in messages.append(m) }