Pass RegisteredState through MessageReceiver, instead of LocalIdentifiers

This commit is contained in:
Sasha Weiss 2026-03-30 15:49:56 -07:00 committed by GitHub
parent ec972806cb
commit 4ab0742210
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 61 additions and 66 deletions

View File

@ -26,7 +26,7 @@ public protocol _SentMessageTranscriptReceiver_EarlyMessageManagerShim {
func applyPendingMessages(
for message: TSMessage,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
)
}
@ -39,8 +39,8 @@ public class _SentMessageTranscriptReceiver_EarlyMessageManagerWrapper: _SentMes
self.earlyMessageManager = earlyMessageManager
}
public func applyPendingMessages(for message: TSMessage, localIdentifiers: LocalIdentifiers, tx: DBWriteTransaction) {
earlyMessageManager.applyPendingMessages(for: message, localIdentifiers: localIdentifiers, transaction: tx)
public func applyPendingMessages(for message: TSMessage, registeredState: RegisteredState, tx: DBWriteTransaction) {
earlyMessageManager.applyPendingMessages(for: message, registeredState: registeredState, transaction: tx)
}
}

View File

@ -17,7 +17,7 @@ public protocol SentMessageTranscriptReceiver {
@discardableResult
func process(
_ sentMessageTranscript: SentMessageTranscript,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) -> Swift.Result<TSOutgoingMessage?, Error>
}

View File

@ -54,7 +54,7 @@ public class SentMessageTranscriptReceiverImpl: SentMessageTranscriptReceiver {
public func process(
_ transcript: SentMessageTranscript,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) -> Result<TSOutgoingMessage?, Error> {
@ -154,7 +154,7 @@ public class SentMessageTranscriptReceiverImpl: SentMessageTranscriptReceiver {
return .failure(OWSAssertionError("Protocol version validation failed"))
}
updateDisappearingMessageTokenIfNecessary(target: target, localIdentifiers: localIdentifiers, tx: tx)
updateDisappearingMessageTokenIfNecessary(target: target, localIdentifiers: registeredState.localIdentifiers, tx: tx)
return .success(nil)
case .message(let messageParams):
Logger.info("Recording transcript in thread: \(messageParams.target.thread.logString) timestamp: \(transcript.timestamp)")
@ -164,7 +164,7 @@ public class SentMessageTranscriptReceiverImpl: SentMessageTranscriptReceiver {
return self.process(
messageParams: messageParams,
transcript: transcript,
localIdentifiers: localIdentifiers,
registeredState: registeredState,
tx: tx,
).map { $0 }
}
@ -173,14 +173,16 @@ public class SentMessageTranscriptReceiverImpl: SentMessageTranscriptReceiver {
private func process(
messageParams: SentMessageTranscriptType.Message,
transcript: SentMessageTranscript,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) -> Result<TSOutgoingMessage, Error> {
guard validateProtocolVersion(for: transcript, thread: messageParams.target.thread, tx: tx) else {
return .failure(OWSAssertionError("Protocol version validation failed"))
}
updateDisappearingMessageTokenIfNecessary(target: messageParams.target, localIdentifiers: localIdentifiers, tx: tx)
let localIdentifiers = registeredState.localIdentifiers
updateDisappearingMessageTokenIfNecessary(target: messageParams.target, localIdentifiers: registeredState.localIdentifiers, tx: tx)
let outgoingMessageBuilder = TSOutgoingMessageBuilder(
thread: messageParams.target.thread,
@ -400,7 +402,7 @@ public class SentMessageTranscriptReceiverImpl: SentMessageTranscriptReceiver {
)
}
self.earlyMessageManager.applyPendingMessages(for: outgoingMessage, localIdentifiers: localIdentifiers, tx: tx)
self.earlyMessageManager.applyPendingMessages(for: outgoingMessage, registeredState: registeredState, tx: tx)
if outgoingMessage.isViewOnceMessage {
// Don't download attachments for "view-once" messages from linked devices.

View File

@ -13,7 +13,7 @@ open class SentMessageTranscriptReceiverMock: SentMessageTranscriptReceiver {
public func process(
_: SentMessageTranscript,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) -> Result<TSOutgoingMessage?, Error> {
// Do nothing

View File

@ -338,10 +338,10 @@ public class EarlyMessageManager {
}
}
public func applyPendingMessages(for message: TSMessage, localIdentifiers: LocalIdentifiers, transaction: DBWriteTransaction) {
public func applyPendingMessages(for message: TSMessage, registeredState: RegisteredState, transaction: DBWriteTransaction) {
let identifier: MessageIdentifier
if let message = message as? TSOutgoingMessage {
identifier = MessageIdentifier(timestamp: message.timestamp, author: localIdentifiers.aci)
identifier = MessageIdentifier(timestamp: message.timestamp, author: registeredState.localIdentifiers.aci)
} else if let message = message as? TSIncomingMessage {
guard let authorAci = Aci.parseFrom(aciString: message.authorUUID) else {
return owsFailDebug("Attempted to apply pending messages for message missing sender aci with type \(message.interactionType) from \(message.authorAddress)")
@ -352,7 +352,7 @@ public class EarlyMessageManager {
return owsFailDebug("attempted to apply pending messages for unsupported message type \(message.interactionType)")
}
applyPendingMessages(for: identifier, localIdentifiers: localIdentifiers, tx: transaction) { earlyReceipt in
applyPendingMessages(for: identifier, registeredState: registeredState, tx: transaction) { earlyReceipt in
switch earlyReceipt {
case .outgoingMessageRead(let sender, let deviceId, let timestamp):
Logger.info("Applying early read receipt from \(sender):\(deviceId) for outgoing message \(identifier)")
@ -432,12 +432,12 @@ public class EarlyMessageManager {
Logger.info("Not processing viewed receipt for system story")
return
}
guard let localIdentifiers = DependenciesBridge.shared.tsAccountManager.localIdentifiers(tx: transaction) else {
guard let registeredState = try? DependenciesBridge.shared.tsAccountManager.registeredState(tx: transaction) else {
owsFailDebug("Can't process messages when not registered.")
return
}
let identifier = MessageIdentifier(timestamp: storyMessage.timestamp, author: storyMessage.authorAci)
applyPendingMessages(for: identifier, localIdentifiers: localIdentifiers, tx: transaction) { earlyReceipt in
applyPendingMessages(for: identifier, registeredState: registeredState, tx: transaction) { earlyReceipt in
switch earlyReceipt {
case .outgoingMessageRead(let sender, let deviceId, _):
owsFailDebug("Unexpectedly received early read receipt from \(sender):\(deviceId) for StoryMessage \(identifier)")
@ -476,7 +476,7 @@ public class EarlyMessageManager {
private func applyPendingMessages(
for identifier: MessageIdentifier,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx transaction: DBWriteTransaction,
earlyReceiptProcessor: (EarlyReceipt) -> Void,
) {
@ -518,7 +518,7 @@ public class EarlyMessageManager {
wasReceivedByUD: earlyEnvelope.wasReceivedByUD,
serverDeliveryTimestamp: earlyEnvelope.serverDeliveryTimestamp,
shouldDiscardVisibleMessages: false,
localIdentifiers: localIdentifiers,
registeredState: registeredState,
tx: transaction,
)
}

View File

@ -346,13 +346,19 @@ class SpecificGroupMessageProcessor {
// Do nothing.
break
case .doNotDiscard, .discardVisibleMessages:
let tsAccountManager = DependenciesBridge.shared.tsAccountManager
guard let registeredState = try? tsAccountManager.registeredState(tx: tx) else {
Logger.warn("Missing registeredState!")
return
}
SSKEnvironment.shared.messageReceiverRef.processEnvelope(
jobInfo.envelope,
plaintextData: jobInfo.plaintextData,
wasReceivedByUD: jobInfo.job.wasReceivedByUD,
serverDeliveryTimestamp: jobInfo.job.serverDeliveryTimestamp,
shouldDiscardVisibleMessages: discardMode == .discardVisibleMessages,
localIdentifiers: DependenciesBridge.shared.tsAccountManager.localIdentifiers(tx: tx)!,
registeredState: registeredState,
tx: tx,
)
}

View File

@ -223,12 +223,12 @@ public class MessageProcessor {
startTime = CACurrentMediaTime()
// This is only called via `drainPendingEnvelopes`, and that confirms that
// we're registered. If we're registered, we must have `LocalIdentifiers`,
// so this (generally) shouldn't fail.
guard let localIdentifiers = DependenciesBridge.shared.tsAccountManager.localIdentifiers(tx: tx) else {
// we're registered. Consequently, this generally shouldn't fail.
let tsAccountManager = DependenciesBridge.shared.tsAccountManager
guard let registeredState = try? tsAccountManager.registeredState(tx: tx) else {
return
}
let localDeviceId = DependenciesBridge.shared.tsAccountManager.storedDeviceId(tx: tx)
let localDeviceId = tsAccountManager.storedDeviceId(tx: tx)
var remainingEnvelopes = batchEnvelopes[...]
while
@ -244,7 +244,7 @@ public class MessageProcessor {
// stop processing envelopes.
let relatedRequests = buildNextCombinedRequest(
envelopes: &remainingEnvelopes,
localIdentifiers: localIdentifiers,
localIdentifiers: registeredState.localIdentifiers,
localDeviceId: localDeviceId,
tx: tx,
)
@ -255,7 +255,7 @@ public class MessageProcessor {
}
handle(
relatedRequests: relatedRequests,
localIdentifiers: localIdentifiers,
registeredState: registeredState,
transaction: tx,
)
}
@ -307,12 +307,12 @@ public class MessageProcessor {
return results
}
private func handle(relatedRequests: [ProcessingRequest], localIdentifiers: LocalIdentifiers, transaction: DBWriteTransaction) {
private func handle(relatedRequests: [ProcessingRequest], registeredState: RegisteredState, transaction: DBWriteTransaction) {
// Efficiently handle delivery receipts for the same message by fetching the sent message only
// once and only using one updateWith... to update the message with new recipient state.
BatchingDeliveryReceiptContext.withDeferredUpdates(transaction: transaction) { context in
for request in relatedRequests {
handleProcessingRequest(request, context: context, localIdentifiers: localIdentifiers, tx: transaction)
handleProcessingRequest(request, context: context, registeredState: registeredState, tx: transaction)
}
}
}
@ -320,7 +320,7 @@ public class MessageProcessor {
private func reallyHandleProcessingRequest(
_ request: ProcessingRequest,
context: DeliveryReceiptContext,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
transaction: DBWriteTransaction,
) {
switch request.state {
@ -335,7 +335,7 @@ public class MessageProcessor {
)
SSKEnvironment.shared.messageReceiverRef.finishProcessingEnvelope(decryptedEnvelope, tx: transaction)
case .messageReceiverRequest(let messageReceiverRequest):
SSKEnvironment.shared.messageReceiverRef.handleRequest(messageReceiverRequest, context: context, localIdentifiers: localIdentifiers, tx: transaction)
SSKEnvironment.shared.messageReceiverRef.handleRequest(messageReceiverRequest, context: context, registeredState: registeredState, tx: transaction)
SSKEnvironment.shared.messageReceiverRef.finishProcessingEnvelope(messageReceiverRequest.decryptedEnvelope, tx: transaction)
case .clearPlaceholdersOnly(let decryptedEnvelope):
SSKEnvironment.shared.messageReceiverRef.finishProcessingEnvelope(decryptedEnvelope, tx: transaction)
@ -347,10 +347,10 @@ public class MessageProcessor {
private func handleProcessingRequest(
_ request: ProcessingRequest,
context: DeliveryReceiptContext,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) {
reallyHandleProcessingRequest(request, context: context, localIdentifiers: localIdentifiers, transaction: tx)
reallyHandleProcessingRequest(request, context: context, registeredState: registeredState, transaction: tx)
tx.addSyncCompletion { request.receivedEnvelope.completion() }
}

View File

@ -125,11 +125,11 @@ public final class MessageReceiver {
wasReceivedByUD: Bool,
serverDeliveryTimestamp: UInt64,
shouldDiscardVisibleMessages: Bool,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) {
do {
let validatedEnvelope = try ValidatedIncomingEnvelope(envelope, localIdentifiers: localIdentifiers)
let validatedEnvelope = try ValidatedIncomingEnvelope(envelope, localIdentifiers: registeredState.localIdentifiers)
switch validatedEnvelope.kind {
case .unidentifiedSender, .identifiedSender:
// At this point, unidentifiedSender envelopes have already been updated
@ -164,7 +164,7 @@ public final class MessageReceiver {
handleRequest(
messageReceiverRequest,
context: PassthroughDeliveryReceiptContext(),
localIdentifiers: localIdentifiers,
registeredState: registeredState,
tx: tx,
)
fallthrough
@ -184,7 +184,7 @@ public final class MessageReceiver {
func handleRequest(
_ request: MessageReceiverRequest,
context: DeliveryReceiptContext,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) {
let protoContent = request.protoContent
@ -200,13 +200,13 @@ public final class MessageReceiver {
switch request.messageType {
case .syncMessage(let syncMessage):
handleIncomingEnvelope(request: request, syncMessage: syncMessage, localIdentifiers: localIdentifiers, tx: tx)
handleIncomingEnvelope(request: request, syncMessage: syncMessage, registeredState: registeredState, tx: tx)
DependenciesBridge.shared.deviceManager.setHasReceivedSyncMessage(transaction: tx)
case .dataMessage(let dataMessage):
handleIncomingEnvelope(request: request, dataMessage: dataMessage, localIdentifiers: localIdentifiers, tx: tx)
handleIncomingEnvelope(request: request, dataMessage: dataMessage, registeredState: registeredState, tx: tx)
case .callMessage(let callMessage):
owsAssertDebug(!request.shouldDiscardVisibleMessages)
handleIncomingEnvelope(request: request, callMessage: callMessage, localIdentifiers: localIdentifiers, tx: tx)
handleIncomingEnvelope(request: request, callMessage: callMessage, registeredState: registeredState, tx: tx)
case .typingMessage(let typingMessage):
handleIncomingEnvelope(request: request, typingMessage: typingMessage, tx: tx)
case .nullMessage:
@ -216,7 +216,7 @@ public final class MessageReceiver {
case .decryptionErrorMessage(let decryptionErrorMessage):
handleIncomingEnvelope(request: request, decryptionErrorMessage: decryptionErrorMessage, tx: tx)
case .storyMessage(let storyMessage):
handleIncomingEnvelope(request: request, storyMessage: storyMessage, localIdentifiers: localIdentifiers, tx: tx)
handleIncomingEnvelope(request: request, storyMessage: storyMessage, registeredState: registeredState, tx: tx)
case .editMessage(let editMessage):
let result = handleIncomingEnvelope(request: request, editMessage: editMessage, tx: tx)
switch result {
@ -362,9 +362,10 @@ public final class MessageReceiver {
private func handleIncomingEnvelope(
request: MessageReceiverRequest,
syncMessage: SSKProtoSyncMessage,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) {
let localIdentifiers = registeredState.localIdentifiers
let decryptedEnvelope = request.decryptedEnvelope
guard decryptedEnvelope.sourceAci == localIdentifiers.aci else {
@ -546,10 +547,6 @@ public final class MessageReceiver {
Logger.warn("Received GroupCallUpdate for invalid groupId")
}
} else if let pollTerminate = dataMessage.pollTerminate {
guard let localIdentifiers = DependenciesBridge.shared.tsAccountManager.localIdentifiers(tx: tx) else {
owsFailDebug("Missing local identifiers!")
return
}
do {
let targetMessage = try DependenciesBridge.shared.pollMessageManager.processIncomingPollTerminate(
pollTerminateProto: pollTerminate,
@ -584,10 +581,6 @@ public final class MessageReceiver {
return
}
} else if let pollVote = dataMessage.pollVote {
guard let localIdentifiers = DependenciesBridge.shared.tsAccountManager.localIdentifiers(tx: tx) else {
owsFailDebug("Missing local identifiers!")
return
}
do {
guard
let (targetMessage, _) = try DependenciesBridge.shared.pollMessageManager.processIncomingPollVote(
@ -637,13 +630,9 @@ public final class MessageReceiver {
owsFailDebug("Could not unpin message \(error)")
}
} else {
guard let localIdentifiers = DependenciesBridge.shared.tsAccountManager.localIdentifiers(tx: tx) else {
owsFailDebug("Missing local identifiers!")
return
}
DependenciesBridge.shared.sentMessageTranscriptReceiver.process(
transcript,
localIdentifiers: localIdentifiers,
registeredState: registeredState,
tx: tx,
)
}
@ -938,7 +927,7 @@ public final class MessageReceiver {
private func handleIncomingEnvelope(
request: MessageReceiverRequest,
dataMessage: SSKProtoDataMessage,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) {
guard SDS.fitsInInt64(dataMessage.timestamp) else {
@ -946,6 +935,7 @@ public final class MessageReceiver {
return
}
let localIdentifiers = registeredState.localIdentifiers
let envelope = request.decryptedEnvelope
if let groupId = self.groupId(for: dataMessage) {
@ -978,7 +968,7 @@ public final class MessageReceiver {
} else if dataMessage.flags & UInt32(SSKProtoDataMessageFlags.profileKeyUpdate.rawValue) != 0 {
// Do nothing, we handle profile keys on all incoming messages above.
} else {
message = processFlaglessDataMessage(dataMessage, request: request, thread: thread, tx: tx)
message = processFlaglessDataMessage(dataMessage, request: request, thread: thread, registeredState: registeredState, tx: tx)
}
// Send delivery receipts for "valid data" messages received via UD.
@ -1069,6 +1059,7 @@ public final class MessageReceiver {
_ dataMessage: SSKProtoDataMessage,
request: MessageReceiverRequest,
thread: TSThread,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) -> TSIncomingMessage? {
let envelope = request.decryptedEnvelope
@ -1676,12 +1667,7 @@ public final class MessageReceiver {
owsAssertDebug(message.insertedMessageHasRenderableContent(rowId: message.sqliteRowId!, tx: tx))
guard let localIdentifiers = DependenciesBridge.shared.tsAccountManager.localIdentifiers(tx: tx) else {
owsFailDebug("Can't process messages when not registered.")
return nil
}
SSKEnvironment.shared.earlyMessageManagerRef.applyPendingMessages(for: message, localIdentifiers: localIdentifiers, transaction: tx)
SSKEnvironment.shared.earlyMessageManagerRef.applyPendingMessages(for: message, registeredState: registeredState, transaction: tx)
// Any messages sent from the current user - from this device or another -
// should be automatically marked as read.
@ -1749,9 +1735,10 @@ public final class MessageReceiver {
private func handleIncomingEnvelope(
request: MessageReceiverRequest,
callMessage: SSKProtoCallMessage,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) {
let localIdentifiers = registeredState.localIdentifiers
let envelope = request.decryptedEnvelope
// If destinationDevice is defined, ignore messages not addressed to this device.
@ -2057,7 +2044,7 @@ public final class MessageReceiver {
private func handleIncomingEnvelope(
request: MessageReceiverRequest,
storyMessage: SSKProtoStoryMessage,
localIdentifiers: LocalIdentifiers,
registeredState: RegisteredState,
tx: DBWriteTransaction,
) {
do {
@ -2065,7 +2052,7 @@ public final class MessageReceiver {
storyMessage,
timestamp: request.decryptedEnvelope.timestamp,
author: request.decryptedEnvelope.sourceAci,
localIdentifiers: localIdentifiers,
localIdentifiers: registeredState.localIdentifiers,
transaction: tx,
)
} catch {

View File

@ -49,7 +49,7 @@ public class CallMessageRelay {
return
}
guard let localIdentifiers = DependenciesBridge.shared.tsAccountManager.localIdentifiers(tx: transaction) else {
guard let registeredState = try? DependenciesBridge.shared.tsAccountManager.registeredState(tx: transaction) else {
owsFailDebug("Can't process VoIP payload when not registered.")
return
}
@ -71,7 +71,7 @@ public class CallMessageRelay {
wasReceivedByUD: payload.wasReceivedByUD,
serverDeliveryTimestamp: adjustedDeliveryTimestamp,
shouldDiscardVisibleMessages: false,
localIdentifiers: localIdentifiers,
registeredState: registeredState,
tx: transaction,
)
}