Migrate sessions to a dedicated table

This commit is contained in:
Max Radermacher 2025-12-29 16:17:31 -06:00 committed by GitHub
parent 9d37668550
commit 4ae6bfe50d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 713 additions and 527 deletions

View File

@ -572,6 +572,7 @@
500FE4E2288A373100FA090C /* BadgeGiftingAlreadyRedeemedSheet.swift in Sources */ = {isa = PBXBuildFile; fileRef = 500FE4E1288A373100FA090C /* BadgeGiftingAlreadyRedeemedSheet.swift */; };
50101FB22B083C8100C648E4 /* ChatListSettingsButtonState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50101FB12B083C8100C648E4 /* ChatListSettingsButtonState.swift */; };
50101FB42B08447000C648E4 /* ChatListProxyButtonCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50101FB32B08447000C648E4 /* ChatListProxyButtonCreator.swift */; };
501050BB2EB959A4005161CA /* SessionStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = 501050BA2EB959A4005161CA /* SessionStore.swift */; };
501052642BDAEEDC0097DDC5 /* MobileCoinExternal.pb.swift in Sources */ = {isa = PBXBuildFile; fileRef = 501052632BDAEEDC0097DDC5 /* MobileCoinExternal.pb.swift */; };
501052672BDB22940097DDC5 /* PrivacyInfo.xcprivacy in Resources */ = {isa = PBXBuildFile; fileRef = 501052652BDB15B90097DDC5 /* PrivacyInfo.xcprivacy */; };
501052692BDB232A0097DDC5 /* PrivacyInfo.xcprivacy in Resources */ = {isa = PBXBuildFile; fileRef = 501052682BDB232A0097DDC5 /* PrivacyInfo.xcprivacy */; };
@ -1806,7 +1807,6 @@
C1C4AA3329E7038D000CE9D3 /* EditManagerShims.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1C4AA3229E7038D000CE9D3 /* EditManagerShims.swift */; };
C1C7E4FB2BE0419300F196EE /* UploadMetadata.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1C7E4FA2BE0419300F196EE /* UploadMetadata.swift */; };
C1CA5F8E2BE2F21C00D733CA /* BackupArchiveDistributionListRecipientArchiver.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CA5F8D2BE2F21C00D733CA /* BackupArchiveDistributionListRecipientArchiver.swift */; };
C1CD0E3A2A6B0D2700307F1A /* SignalSessionStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CD0E392A6B0D2700307F1A /* SignalSessionStore.swift */; };
C1CD0E402A6B37BF00307F1A /* PreKeyStoreImplTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CD0E3F2A6B37BF00307F1A /* PreKeyStoreImplTest.swift */; };
C1CF83D02B96C85E00CDC9C4 /* ChunkedOutputStreamTransform.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CF83CF2B96C85E00CDC9C4 /* ChunkedOutputStreamTransform.swift */; };
C1CF83D22B9A1FCB00CDC9C4 /* GzipStreamTransform.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CF83D12B9A1FCB00CDC9C4 /* GzipStreamTransform.swift */; };
@ -3739,7 +3739,6 @@
F9C5CD33289453B300548EEE /* SignedPreKeyStoreImpl.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA55289453B100548EEE /* SignedPreKeyStoreImpl.swift */; };
F9C5CD34289453B300548EEE /* SignalProtocolStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA56289453B100548EEE /* SignalProtocolStore.swift */; };
F9C5CD37289453B300548EEE /* SenderKeyStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA59289453B100548EEE /* SenderKeyStore.swift */; };
F9C5CD3C289453B300548EEE /* SSKSessionStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA5E289453B100548EEE /* SSKSessionStore.swift */; };
F9C5CD52289453B300548EEE /* PreKeyStoreImpl.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA75289453B100548EEE /* PreKeyStoreImpl.swift */; };
F9C5CD54289453B300548EEE /* SSKKeychainStorage.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA77289453B100548EEE /* SSKKeychainStorage.swift */; };
F9C5CD56289453B300548EEE /* PendingViewedReceiptRecord.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA79289453B100548EEE /* PendingViewedReceiptRecord.swift */; };
@ -4716,6 +4715,7 @@
500FE4E1288A373100FA090C /* BadgeGiftingAlreadyRedeemedSheet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BadgeGiftingAlreadyRedeemedSheet.swift; sourceTree = "<group>"; };
50101FB12B083C8100C648E4 /* ChatListSettingsButtonState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatListSettingsButtonState.swift; sourceTree = "<group>"; };
50101FB32B08447000C648E4 /* ChatListProxyButtonCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatListProxyButtonCreator.swift; sourceTree = "<group>"; };
501050BA2EB959A4005161CA /* SessionStore.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SessionStore.swift; sourceTree = "<group>"; };
501052632BDAEEDC0097DDC5 /* MobileCoinExternal.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MobileCoinExternal.pb.swift; sourceTree = "<group>"; };
501052652BDB15B90097DDC5 /* PrivacyInfo.xcprivacy */ = {isa = PBXFileReference; lastKnownFileType = text.xml; path = PrivacyInfo.xcprivacy; sourceTree = "<group>"; };
501052682BDB232A0097DDC5 /* PrivacyInfo.xcprivacy */ = {isa = PBXFileReference; lastKnownFileType = text.xml; path = PrivacyInfo.xcprivacy; sourceTree = "<group>"; };
@ -5963,7 +5963,6 @@
C1C4AA3229E7038D000CE9D3 /* EditManagerShims.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = EditManagerShims.swift; sourceTree = "<group>"; };
C1C7E4FA2BE0419300F196EE /* UploadMetadata.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UploadMetadata.swift; sourceTree = "<group>"; };
C1CA5F8D2BE2F21C00D733CA /* BackupArchiveDistributionListRecipientArchiver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BackupArchiveDistributionListRecipientArchiver.swift; sourceTree = "<group>"; };
C1CD0E392A6B0D2700307F1A /* SignalSessionStore.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SignalSessionStore.swift; sourceTree = "<group>"; };
C1CD0E3F2A6B37BF00307F1A /* PreKeyStoreImplTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreKeyStoreImplTest.swift; sourceTree = "<group>"; };
C1CF83CF2B96C85E00CDC9C4 /* ChunkedOutputStreamTransform.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChunkedOutputStreamTransform.swift; sourceTree = "<group>"; };
C1CF83D12B9A1FCB00CDC9C4 /* GzipStreamTransform.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GzipStreamTransform.swift; sourceTree = "<group>"; };
@ -7934,7 +7933,6 @@
F9C5CA55289453B100548EEE /* SignedPreKeyStoreImpl.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SignedPreKeyStoreImpl.swift; sourceTree = "<group>"; };
F9C5CA56289453B100548EEE /* SignalProtocolStore.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SignalProtocolStore.swift; sourceTree = "<group>"; };
F9C5CA59289453B100548EEE /* SenderKeyStore.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SenderKeyStore.swift; sourceTree = "<group>"; };
F9C5CA5E289453B100548EEE /* SSKSessionStore.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SSKSessionStore.swift; sourceTree = "<group>"; };
F9C5CA75289453B100548EEE /* PreKeyStoreImpl.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PreKeyStoreImpl.swift; sourceTree = "<group>"; };
F9C5CA77289453B100548EEE /* SSKKeychainStorage.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SSKKeychainStorage.swift; sourceTree = "<group>"; };
F9C5CA79289453B100548EEE /* PendingViewedReceiptRecord.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PendingViewedReceiptRecord.swift; sourceTree = "<group>"; };
@ -14881,10 +14879,9 @@
50589CDD2E8C44D5003EF42A /* PreKeyStore.swift */,
F9C5CA75289453B100548EEE /* PreKeyStoreImpl.swift */,
F9C5CA59289453B100548EEE /* SenderKeyStore.swift */,
501050BA2EB959A4005161CA /* SessionStore.swift */,
F9C5CA56289453B100548EEE /* SignalProtocolStore.swift */,
C1CD0E392A6B0D2700307F1A /* SignalSessionStore.swift */,
F9C5CA55289453B100548EEE /* SignedPreKeyStoreImpl.swift */,
F9C5CA5E289453B100548EEE /* SSKSessionStore.swift */,
);
path = AxolotlStore;
sourceTree = "<group>";
@ -19328,6 +19325,7 @@
F9C5CC96289453B300548EEE /* SessionRecord.pb.swift in Sources */,
725465242BA017D500EABFD2 /* SessionResetJob.swift in Sources */,
D9AE0AD729187A700063488B /* SessionResetJobRecord.swift in Sources */,
501050BB2EB959A4005161CA /* SessionStore.swift in Sources */,
1700E34128BD41150073D949 /* SetAlgebra+SSK.swift in Sources */,
502346772DB039320029DB97 /* SetDeque.swift in Sources */,
66C2B14D2A13E2C7008DDE72 /* SgxWebsocketConfigurator.swift in Sources */,
@ -19352,7 +19350,6 @@
F9C5CC9A289453B300548EEE /* SignalService.pb.swift in Sources */,
F9C5CCE2289453B300548EEE /* SignalServiceAddress.swift in Sources */,
F9C5CDBB289453B400548EEE /* SignalServiceProfile.swift in Sources */,
C1CD0E3A2A6B0D2700307F1A /* SignalSessionStore.swift in Sources */,
72B0C2422C9EED0E00B57DAD /* SignedPreKeyRecord.swift in Sources */,
F9C5CD33289453B300548EEE /* SignedPreKeyStoreImpl.swift in Sources */,
F9C5CC51289453B300548EEE /* SMKError.swift in Sources */,
@ -19371,7 +19368,6 @@
F9C5CCA4289453B300548EEE /* SSKProto+OWS.swift in Sources */,
F9C5CCA1289453B300548EEE /* SSKProto.swift in Sources */,
F9C5CC8E289453B300548EEE /* SSKProtos.swift in Sources */,
F9C5CD3C289453B300548EEE /* SSKSessionStore.swift in Sources */,
F9C5CD9E289453B400548EEE /* SSKWebSocket.swift in Sources */,
F9C5CC1B289453B300548EEE /* StickerError.swift in Sources */,
F9C5CC13289453B300548EEE /* StickerInfo.m in Sources */,

View File

@ -23,13 +23,13 @@ class DebugUISessionState: DebugUIPage {
OWSTableItem(title: "Delete All Sessions", actionBlock: {
SSKEnvironment.shared.databaseStorageRef.write { transaction in
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
sessionStore.deleteAllSessions(for: contactThread.contactAddress.serviceId!, tx: transaction)
sessionStore.deleteSessions(forServiceId: contactThread.contactAddress.serviceId!, tx: transaction)
}
}),
OWSTableItem(title: "Archive All Sessions", actionBlock: {
SSKEnvironment.shared.databaseStorageRef.write { transaction in
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
sessionStore.archiveAllSessions(for: contactThread.contactAddress.serviceId!, tx: transaction)
sessionStore.archiveSessions(forServiceId: contactThread.contactAddress.serviceId!, tx: transaction)
}
}),
]

View File

@ -135,8 +135,8 @@ public class ConversationInternalViewController: OWSTableViewController2 {
let sessionSection = OWSTableSection()
sessionSection.add(.actionItem(withText: "Delete Session") {
SSKEnvironment.shared.databaseStorageRef.write { transaction in
let aciStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci)
aciStore.sessionStore.deleteAllSessions(for: address.serviceId!, tx: transaction)
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
sessionStore.deleteSessions(forServiceId: address.serviceId!, tx: transaction)
}
})

View File

@ -66,6 +66,7 @@ public class ProvisioningCoordinatorTest: XCTestCase {
self.tsAccountManagerMock = .init()
self.udManagerMock = .init()
let preKeyStore = PreKeyStore()
let sessionStore = SignalServiceKit.SessionStore()
self.provisioningCoordinator = ProvisioningCoordinatorImpl(
chatConnectionManager: chatConnectionManagerMock,
@ -81,9 +82,10 @@ public class ProvisioningCoordinatorTest: XCTestCase {
registrationStateChangeManager: registrationStateChangeManagerMock,
registrationWebSocketManager: MockRegistrationWebSocketManager(),
signalProtocolStoreManager: SignalProtocolStoreManager(
aciProtocolStore: .mock(identity: .aci, preKeyStore: preKeyStore),
pniProtocolStore: .mock(identity: .pni, preKeyStore: preKeyStore),
aciProtocolStore: .mock(identity: .aci, preKeyStore: preKeyStore, recipientIdFinder: recipientIdFinder, sessionStore: sessionStore),
pniProtocolStore: .mock(identity: .pni, preKeyStore: preKeyStore, recipientIdFinder: recipientIdFinder, sessionStore: sessionStore),
preKeyStore: preKeyStore,
sessionStore: sessionStore,
),
signalService: signalServiceMock,
storageServiceManager: storageServiceManagerMock,

View File

@ -211,8 +211,7 @@ public class RegistrationStateChangeManagerImpl: RegistrationStateChangeManager
tx: tx
)
signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore.resetSessionStore(tx: tx)
signalProtocolStoreManager.signalProtocolStore(for: .pni).sessionStore.resetSessionStore(tx: tx)
signalProtocolStoreManager.sessionStore.deleteAllSessions(tx: tx)
senderKeyStore.resetSenderKeyStore(transaction: tx)
udManager.removeSenderCertificates(tx: tx)
versionedProfiles.clearProfileKeyCredentials(tx: tx)

View File

@ -94,13 +94,13 @@ struct MergedRecipient {
}
class RecipientMergerImpl: RecipientMerger {
private let aciSessionStore: SignalSessionStore
private let blockedRecipientStore: BlockedRecipientStore
private let identityManager: OWSIdentityManager
private let observers: Observers
private let recipientDatabaseTable: RecipientDatabaseTable
private let recipientFetcher: RecipientFetcher
private let searchableNameIndexer: any SearchableNameIndexer
private let sessionStore: SessionStore
private let storageServiceManager: StorageServiceManager
private let storyRecipientStore: StoryRecipientStore
@ -111,23 +111,23 @@ class RecipientMergerImpl: RecipientMerger {
/// which we learned about the new association, and they are notified in the
/// order in which they are provided.
init(
aciSessionStore: SignalSessionStore,
blockedRecipientStore: BlockedRecipientStore,
identityManager: OWSIdentityManager,
observers: Observers,
recipientDatabaseTable: RecipientDatabaseTable,
recipientFetcher: RecipientFetcher,
searchableNameIndexer: any SearchableNameIndexer,
sessionStore: SessionStore,
storageServiceManager: StorageServiceManager,
storyRecipientStore: StoryRecipientStore
) {
self.aciSessionStore = aciSessionStore
self.blockedRecipientStore = blockedRecipientStore
self.identityManager = identityManager
self.observers = observers
self.recipientDatabaseTable = recipientDatabaseTable
self.recipientFetcher = recipientFetcher
self.searchableNameIndexer = searchableNameIndexer
self.sessionStore = sessionStore
self.storageServiceManager = storageServiceManager
self.storyRecipientStore = storyRecipientStore
}
@ -801,7 +801,7 @@ class RecipientMergerImpl: RecipientMerger {
for affectedRecipient in affectedRecipients {
if affectedRecipient.isEmpty {
// TODO: Should we clean up any more state related to the discarded recipient?
aciSessionStore.mergeRecipient(affectedRecipient, into: mergedRecipient, tx: tx)
sessionStore.mergeRecipientId(affectedRecipient.id, into: mergedRecipient.id, localIdentity: .aci, tx: tx)
identityManager.mergeRecipient(affectedRecipient, into: mergedRecipient, tx: tx)
blockedRecipientStore.mergeRecipientId(affectedRecipient.id, into: mergedRecipient.id, tx: tx)
failIfThrows { try storyRecipientStore.mergeRecipient(affectedRecipient, into: mergedRecipient, tx: tx) }
@ -860,7 +860,7 @@ class RecipientMergerImpl: RecipientMerger {
intoValue: newRecipient.isEmpty ? mergedRecipient : newRecipient
)
guard aciSessionStore.mightContainSession(for: recipientPair.fromValue, tx: tx) else {
guard sessionStore.hasSessionRecords(forRecipientId: recipientPair.fromValue.id, localIdentity: .aci, tx: tx) else {
continue
}
@ -894,7 +894,7 @@ class RecipientMergerImpl: RecipientMerger {
// the session/identity for these recipients.
if recipientPair.fromValue.uniqueId == recipientPair.intoValue.uniqueId {
identityManager.removeRecipientIdentity(for: recipientPair.fromValue.uniqueId, tx: tx)
aciSessionStore.deleteAllSessions(for: recipientPair.fromValue.uniqueId, tx: tx)
sessionStore.deleteSessions(forRecipientId: recipientPair.fromValue.id, localIdentity: .aci, tx: tx)
}
// The canonical case is adding an ACI to a recipient that already had a

View File

@ -527,7 +527,6 @@ public class SentMessageTranscriptReceiverImpl: SentMessageTranscriptReceiver {
}
private func archiveSessions(for address: SignalServiceAddress, tx: DBWriteTransaction) {
let sessionStore = signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
sessionStore.archiveAllSessions(for: address, tx: tx)
self.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore.archiveSessions(forAddress: address, tx: tx)
}
}

View File

@ -276,12 +276,14 @@ extension AppSetup.GlobalsContinuation {
)
let preKeyStore = PreKeyStore()
let sessionStore = SessionStore()
let aciProtocolStore = SignalProtocolStore.build(
dateProvider: dateProvider,
identity: .aci,
preKeyStore: preKeyStore,
recipientIdFinder: recipientIdFinder,
sessionStore: sessionStore,
)
let blockedRecipientStore = BlockedRecipientStore()
let blockingManager = BlockingManager(
@ -305,6 +307,7 @@ extension AppSetup.GlobalsContinuation {
identity: .pni,
preKeyStore: preKeyStore,
recipientIdFinder: recipientIdFinder,
sessionStore: sessionStore,
)
let profileManager = testDependencies.profileManager ?? OWSProfileManager(
appReadiness: appReadiness,
@ -320,6 +323,7 @@ extension AppSetup.GlobalsContinuation {
aciProtocolStore: aciProtocolStore,
pniProtocolStore: pniProtocolStore,
preKeyStore: preKeyStore,
sessionStore: sessionStore,
)
let signalService = testDependencies.signalService ?? OWSSignalService(libsignalNet: libsignalNet)
let signalServiceAddressCache = SignalServiceAddressCache()
@ -764,7 +768,6 @@ extension AppSetup.GlobalsContinuation {
let badgeCountFetcher = BadgeCountFetcherImpl()
let identityManager = OWSIdentityManagerImpl(
aciProtocolStore: aciProtocolStore,
appReadiness: appReadiness,
db: db,
messageSenderJobQueue: messageSenderJobQueue,
@ -775,6 +778,7 @@ extension AppSetup.GlobalsContinuation {
recipientDatabaseTable: recipientDatabaseTable,
recipientFetcher: recipientFetcher,
recipientIdFinder: recipientIdFinder,
sessionStore: sessionStore,
storageServiceManager: storageServiceManager,
tsAccountManager: tsAccountManager
)
@ -1036,7 +1040,6 @@ extension AppSetup.GlobalsContinuation {
let authorMergeHelper = AuthorMergeHelper()
let recipientMerger = RecipientMergerImpl(
aciSessionStore: aciProtocolStore.sessionStore,
blockedRecipientStore: blockedRecipientStore,
identityManager: identityManager,
observers: RecipientMergerImpl.buildObservers(
@ -1063,6 +1066,7 @@ extension AppSetup.GlobalsContinuation {
recipientDatabaseTable: recipientDatabaseTable,
recipientFetcher: recipientFetcher,
searchableNameIndexer: searchableNameIndexer,
sessionStore: sessionStore,
storageServiceManager: storageServiceManager,
storyRecipientStore: storyRecipientStore
)

View File

@ -71,6 +71,11 @@ public enum BuildFlags {
// that's now dead because this is false.
public static let decodeDeprecatedPreKeys = true
// Turn this off after all still-registered clients have run this
// migration. That should happen by 2026-08-04. Then, delete all the code
// that's now dead because this is false.
public static let migrateDeprecatedSessions = true
public static let serviceIdBinaryProvisioning = true
public static let serviceIdBinaryConstantOverhead = !serviceIdStrings || (build <= .internal)
public static let serviceIdBinaryVariableOverhead = !serviceIdStrings || (build <= .dev)

View File

@ -106,6 +106,6 @@ private class SessionResetJobRunner: JobRunner {
private func archiveAllSessions(for contactThread: TSContactThread, tx: DBWriteTransaction) {
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
sessionStore.archiveAllSessions(for: contactThread.contactAddress, tx: tx)
sessionStore.archiveSessions(forAddress: contactThread.contactAddress, tx: tx)
}
}

View File

@ -1907,7 +1907,7 @@ public final class MessageReceiver {
let sessionRecord = try sessionStore.loadSession(for: protocolAddress, context: tx)
if try sessionRecord?.currentRatchetKeyMatches(ratchetKey) == true {
Logger.info("Decryption error included ratchet key. Archiving...")
sessionStore.archiveSession(for: sourceAci, deviceId: sourceDeviceId, tx: tx)
sessionStore.archiveSession(forServiceId: sourceAci, deviceId: sourceDeviceId, tx: tx)
didPerformSessionReset = true
} else {
didPerformSessionReset = false
@ -2188,7 +2188,7 @@ public final class MessageReceiver {
TSInfoMessage(thread: thread, messageType: .typeRemoteUserEndedSession).anyInsert(transaction: tx)
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
sessionStore.archiveAllSessions(for: decryptedEnvelope.sourceAci, tx: tx)
sessionStore.archiveSessions(forServiceId: decryptedEnvelope.sourceAci, tx: tx)
}
}

View File

@ -48,10 +48,10 @@ public class MessageSender {
// MARK: - Creating Signal Protocol Sessions
private func validSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBReadTransaction) throws -> SessionRecord? {
private func validSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBReadTransaction) throws -> LibSignalClient.SessionRecord? {
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
do {
guard let session = try sessionStore.loadSession(for: serviceId, deviceId: deviceId, tx: tx) else {
guard let session = try sessionStore.loadSession(forServiceId: serviceId, deviceId: deviceId, tx: tx) else {
return nil
}
guard session.hasCurrentState else {
@ -1555,7 +1555,7 @@ public class MessageSender {
Logger.warn("Found identity key mismatch on outgoing message to \(serviceId).\(deviceId). Archiving session before retrying...")
let signalProtocolStoreManager = DependenciesBridge.shared.signalProtocolStoreManager
let aciSessionStore = signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
aciSessionStore.archiveSession(for: serviceId, deviceId: deviceId, tx: tx)
aciSessionStore.archiveSession(forServiceId: serviceId, deviceId: deviceId, tx: tx)
throw OWSRetryableMessageSenderError()
} catch SignalError.untrustedIdentity {
Logger.warn("Found untrusted identity on outgoing message to \(serviceId). Wrapping error and throwing...")
@ -1755,7 +1755,7 @@ public class MessageSender {
Logger.warn("Stale devices for \(serviceId): \(staleDevices)")
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
for staleDeviceId in staleDevices {
sessionStore.archiveSession(for: serviceId, deviceId: staleDeviceId, tx: tx)
sessionStore.archiveSession(forServiceId: serviceId, deviceId: staleDeviceId, tx: tx)
}
}
@ -1809,7 +1809,7 @@ public class MessageSender {
Logger.info("Archiving sessions for extra devices: \(devicesToRemove)")
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
for deviceId in devicesToRemove {
sessionStore.archiveSession(for: serviceId, deviceId: deviceId, tx: tx)
sessionStore.archiveSession(forServiceId: serviceId, deviceId: deviceId, tx: tx)
}
}
}

View File

@ -189,7 +189,6 @@ private extension OWSIdentity {
}
public class OWSIdentityManagerImpl: OWSIdentityManager {
private let aciProtocolStore: SignalProtocolStore
private let appReadiness: AppReadiness
private let db: any DB
private let messageSenderJobQueue: MessageSenderJobQueue
@ -202,12 +201,12 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
private let recipientDatabaseTable: RecipientDatabaseTable
private let recipientFetcher: RecipientFetcher
private let recipientIdFinder: RecipientIdFinder
private let sessionStore: SessionStore
private let shareMyPhoneNumberStore: KeyValueStore
private let storageServiceManager: StorageServiceManager
private let tsAccountManager: TSAccountManager
public init(
aciProtocolStore: SignalProtocolStore,
init(
appReadiness: AppReadiness,
db: any DB,
messageSenderJobQueue: MessageSenderJobQueue,
@ -218,10 +217,10 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
recipientDatabaseTable: RecipientDatabaseTable,
recipientFetcher: RecipientFetcher,
recipientIdFinder: RecipientIdFinder,
sessionStore: SessionStore,
storageServiceManager: StorageServiceManager,
tsAccountManager: TSAccountManager
) {
self.aciProtocolStore = aciProtocolStore
self.appReadiness = appReadiness
self.db = db
self.messageSenderJobQueue = messageSenderJobQueue
@ -238,6 +237,7 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
self.recipientDatabaseTable = recipientDatabaseTable
self.recipientFetcher = recipientFetcher
self.recipientIdFinder = recipientIdFinder
self.sessionStore = sessionStore
self.shareMyPhoneNumberStore = KeyValueStore(
collection: "OWSIdentityManager.shareMyPhoneNumberStore"
)
@ -348,27 +348,27 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
@discardableResult
public func saveIdentityKey(_ identityKey: Data, for serviceId: ServiceId, tx: DBWriteTransaction) -> Result<IdentityChange, RecipientIdError> {
let recipientIdResult = recipientIdFinder.ensureRecipientUniqueId(for: serviceId, tx: tx)
return recipientIdResult.map({ _saveIdentityKey(identityKey, for: serviceId, recipientUniqueId: $0, tx: tx) })
let recipientResult = recipientIdFinder.ensureRecipient(for: serviceId, tx: tx)
return recipientResult.map({ _saveIdentityKey(identityKey, for: serviceId, recipient: $0, tx: tx) })
}
private func _saveIdentityKey(_ identityKey: Data, for serviceId: ServiceId, recipientUniqueId: RecipientUniqueId, tx: DBWriteTransaction) -> IdentityChange {
private func _saveIdentityKey(_ identityKey: Data, for serviceId: ServiceId, recipient: SignalRecipient, tx: DBWriteTransaction) -> IdentityChange {
owsAssertDebug(identityKey.count == Constants.storedIdentityKeyLength)
let existingIdentity = OWSRecipientIdentity.anyFetch(uniqueId: recipientUniqueId, transaction: tx)
let existingIdentity = OWSRecipientIdentity.anyFetch(uniqueId: recipient.uniqueId, transaction: tx)
guard let existingIdentity else {
Logger.info("Saving first-use identity for \(serviceId)")
OWSRecipientIdentity(
uniqueId: recipientUniqueId,
uniqueId: recipient.uniqueId,
identityKey: identityKey,
isFirstKnownKey: true,
createdAt: Date(),
verificationState: .default
).anyInsert(transaction: tx)
// Cancel any pending verification state sync messages for this recipient.
clearSyncMessage(for: recipientUniqueId, tx: tx)
clearSyncMessage(for: recipient.uniqueId, tx: tx)
fireIdentityStateChangeNotification(after: tx)
storageServiceManager.recordPendingUpdates(updatedRecipientUniqueIds: [recipientUniqueId])
storageServiceManager.recordPendingUpdates(updatedRecipientUniqueIds: [recipient.uniqueId])
return .newOrUnchanged
}
@ -386,16 +386,16 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
Logger.info("Saving new identity for \(serviceId): \(existingIdentity.verificationState) -> \(verificationState)")
insertIdentityChangeInfoMessage(for: serviceId, wasIdentityVerified: existingIdentity.wasIdentityVerified, tx: tx)
OWSRecipientIdentity(
uniqueId: recipientUniqueId,
uniqueId: recipient.uniqueId,
identityKey: identityKey,
isFirstKnownKey: false,
createdAt: Date(),
verificationState: verificationState.rawValue
).anyUpsert(transaction: tx)
aciProtocolStore.sessionStore.archiveAllSessions(for: serviceId, tx: tx)
sessionStore.archiveSessions(forRecipientId: recipient.id, localIdentity: .aci, tx: tx)
// Cancel any pending verification state sync messages for this recipient.
clearSyncMessage(for: recipientUniqueId, tx: tx)
storageServiceManager.recordPendingUpdates(updatedRecipientUniqueIds: [recipientUniqueId])
clearSyncMessage(for: recipient.uniqueId, tx: tx)
storageServiceManager.recordPendingUpdates(updatedRecipientUniqueIds: [recipient.uniqueId])
return .replacedExisting
}

View File

@ -364,8 +364,7 @@ public class OWSMessageDecrypter {
Logger.warn("Archiving session for undecryptable message from \(senderId)")
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
sessionStore.archiveSession(for: sourceAci, deviceId: sourceDeviceId, tx: transaction)
sessionStore.archiveSession(forServiceId: sourceAci, deviceId: sourceDeviceId, tx: transaction)
trySendNullMessage(in: contactThread, senderId: senderId, transaction: transaction)
return true
} else {

View File

@ -77,7 +77,7 @@ fileprivate extension SMKMessageType {
@objc
public class SMKSecretSessionCipher: NSObject {
private let currentSessionStore: SessionStore
private let currentSessionStore: LibSignalClient.SessionStore
private let currentPreKeyStore: LibSignalClient.PreKeyStore
private let currentSignedPreKeyStore: SignedPreKeyStore
private let currentKyberPreKeyStore: KyberPreKeyStore
@ -86,7 +86,7 @@ public class SMKSecretSessionCipher: NSObject {
// public SecretSessionCipher(SignalProtocolStore signalProtocolStore) {
init(
sessionStore: SessionStore,
sessionStore: LibSignalClient.SessionStore,
preKeyStore: LibSignalClient.PreKeyStore,
signedPreKeyStore: SignedPreKeyStore,
kyberPreKeyStore: KyberPreKeyStore,

View File

@ -9,9 +9,9 @@ import Foundation
import LibSignalClient
extension SignalProtocolStore {
static func mock(identity: OWSIdentity, preKeyStore: PreKeyStore) -> Self {
static func mock(identity: OWSIdentity, preKeyStore: PreKeyStore, recipientIdFinder: RecipientIdFinder, sessionStore: SessionStore) -> Self {
return SignalProtocolStore(
sessionStore: MockSessionStore(),
sessionStore: SessionManagerForIdentity(identity: identity, recipientIdFinder: recipientIdFinder, sessionStore: sessionStore),
preKeyStore: PreKeyStoreImpl(for: identity, preKeyStore: preKeyStore),
signedPreKeyStore: SignedPreKeyStoreImpl(for: identity, preKeyStore: preKeyStore),
kyberPreKeyStore: KyberPreKeyStoreImpl(for: identity, dateProvider: Date.provider, preKeyStore: preKeyStore),
@ -19,20 +19,4 @@ extension SignalProtocolStore {
}
}
class MockSessionStore: SignalSessionStore {
func mightContainSession(for recipient: SignalRecipient, tx: DBReadTransaction) -> Bool { false }
func mergeRecipient(_ recipient: SignalRecipient, into targetRecipient: SignalRecipient, tx: DBWriteTransaction) { }
func archiveAllSessions(for serviceId: ServiceId, tx: DBWriteTransaction) { }
func archiveAllSessions(for address: SignalServiceAddress, tx: DBWriteTransaction) { }
func archiveSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBWriteTransaction) { }
func loadSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBReadTransaction) throws -> LibSignalClient.SessionRecord? { nil }
func loadSession(for address: ProtocolAddress, context: StoreContext) throws -> LibSignalClient.SessionRecord? { nil }
func resetSessionStore(tx: DBWriteTransaction) { }
func deleteAllSessions(for serviceId: ServiceId, tx: DBWriteTransaction) { }
func deleteAllSessions(for recipientUniqueId: RecipientUniqueId, tx: DBWriteTransaction) { }
func removeAll(tx: DBWriteTransaction) { }
func loadExistingSessions(for addresses: [ProtocolAddress], context: StoreContext) throws -> [LibSignalClient.SessionRecord] { [] }
func storeSession(_ record: LibSignalClient.SessionRecord, for address: ProtocolAddress, context: StoreContext) throws { }
}
#endif

View File

@ -1,311 +0,0 @@
//
// Copyright 2020 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
public import LibSignalClient
public final class SSKSessionStore: SignalSessionStore {
// Note that even though the values here are always serialized Data,
// using AnyObject here defers any checking or conversion of the values
// when converting from an NSDictionary.
fileprivate typealias SessionsByDeviceDictionary = [Int32: AnyObject]
private let keyValueStore: KeyValueStore
private let recipientIdFinder: RecipientIdFinder
public init(
for identity: OWSIdentity,
recipientIdFinder: RecipientIdFinder
) {
self.keyValueStore = KeyValueStore(collection: {
switch identity {
case .aci:
return "TSStorageManagerSessionStoreCollection"
case .pni:
return "TSStorageManagerPNISessionStoreCollection"
}
}())
self.recipientIdFinder = recipientIdFinder
}
fileprivate func loadSerializedSession(
for serviceId: ServiceId,
deviceId: UInt32,
tx: DBReadTransaction
) throws -> Data? {
switch recipientIdFinder.recipientUniqueId(for: serviceId, tx: tx) {
case .none:
return nil
case .some(.success(let recipientUniqueId)):
return loadSerializedSession(for: recipientUniqueId, deviceId: deviceId, tx: tx)
case .some(.failure(let error)):
switch error {
case .mustNotUsePniBecauseAciExists:
throw error
}
}
}
private func serializedSession(fromDatabaseRepresentation entry: Any) -> Data? {
switch entry {
case let data as Data:
return data
default:
owsFailDebug("unexpected entry in session store: \(type(of: entry))")
return nil
}
}
private func loadSerializedSession(
for recipientUniqueId: String,
deviceId: UInt32,
tx: DBReadTransaction
) -> Data? {
owsAssertDebug(!recipientUniqueId.isEmpty)
owsAssertDebug(deviceId > 0)
let dictionary = loadAllSerializedSessions(for: recipientUniqueId, tx: tx)
guard let entry = dictionary?[Int32(bitPattern: deviceId)] else {
return nil
}
return serializedSession(fromDatabaseRepresentation: entry)
}
private func loadAllSerializedSessions(
for recipientUniqueId: String,
tx: DBReadTransaction
) -> SessionsByDeviceDictionary? {
owsAssertDebug(!recipientUniqueId.isEmpty)
guard let serialized = keyValueStore.getData(recipientUniqueId, transaction: tx) else {
return nil
}
let rawDictionary: NSDictionary?
do {
rawDictionary = try NSKeyedUnarchiver.unarchivedObject(ofClasses: [NSDictionary.self, NSNumber.self, NSData.self], from: serialized) as? NSDictionary
} catch let error as NSError {
// Deliberately don't log the full error; it might contain session data.
Logger.error("Unknown data (or legacy session) in session store; continuing as if there were no stored sessions (\(error.domain) \(error.code))")
return nil
}
guard let dictionary = rawDictionary as? SessionsByDeviceDictionary else {
Logger.error("Invalid device ID keys in session store; continuing as if there were no stored sessions")
return nil
}
return dictionary
}
fileprivate func storeSerializedSession(
_ sessionData: Data,
for serviceId: ServiceId,
deviceId: UInt32,
tx: DBWriteTransaction
) throws {
switch recipientIdFinder.ensureRecipientUniqueId(for: serviceId, tx: tx) {
case .failure(let error):
switch error {
case .mustNotUsePniBecauseAciExists:
throw error
}
case .success(let recipientUniqueId):
storeSerializedSession(for: recipientUniqueId, deviceId: deviceId, sessionData: sessionData, tx: tx)
}
}
private func storeSerializedSession(
for recipientUniqueId: String,
deviceId: UInt32,
sessionData: Data,
tx: DBWriteTransaction
) {
owsAssertDebug(!recipientUniqueId.isEmpty)
owsAssertDebug(deviceId > 0)
var dictionary = loadAllSerializedSessions(for: recipientUniqueId, tx: tx) ?? [:]
dictionary[Int32(bitPattern: deviceId)] = sessionData as NSData
saveSerializedSessions(dictionary, for: recipientUniqueId, tx: tx)
}
private func saveSerializedSessions(
_ sessions: SessionsByDeviceDictionary,
for recipientUniqueId: String,
tx: DBWriteTransaction
) {
// Avoid using KeyValueStore.setObject(_:key:transaction:).
// The database-based KV store implicitly archives using NSKeyedArchiver,
// but the in-memory one for testing does not.
// In order for loadAllSerializedSessions(for:tx:) to manually control deserialization,
// we need to consistently archive.
// This will also make it easier to potentially move away from NSKeyedArchiver in the future.
do {
let archived = try NSKeyedArchiver.archivedData(withRootObject: sessions, requiringSecureCoding: true)
keyValueStore.setData(archived, key: recipientUniqueId, transaction: tx)
} catch {
Logger.debug("failed to serialize session data: \(error)\n\(sessions)")
owsFailDebug("failed to serialize session data")
// At least clear out whatever's in the store, so we don't keep old sessions around longer than we should.
keyValueStore.setData(nil, key: recipientUniqueId, transaction: tx)
}
}
public func mightContainSession(for recipient: SignalRecipient, tx: DBReadTransaction) -> Bool {
return keyValueStore.hasValue(recipient.uniqueId, transaction: tx)
}
public func mergeRecipient(_ recipient: SignalRecipient, into targetRecipient: SignalRecipient, tx: DBWriteTransaction) {
let recipientPair = MergePair(fromValue: recipient, intoValue: targetRecipient)
let sessionBlob = recipientPair.map { keyValueStore.getData($0.uniqueId, transaction: tx) }
guard let fromValue = sessionBlob.fromValue else {
return
}
if sessionBlob.intoValue == nil {
keyValueStore.setData(fromValue, key: targetRecipient.uniqueId, transaction: tx)
}
keyValueStore.removeValue(forKey: recipient.uniqueId, transaction: tx)
}
public func deleteAllSessions(for serviceId: ServiceId, tx: DBWriteTransaction) {
Logger.info("deleting all sessions for \(serviceId)")
switch recipientIdFinder.recipientUniqueId(for: serviceId, tx: tx) {
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
// There can't possibly be any sessions that need to be deleted.
return
case .some(.success(let recipientUniqueId)):
owsAssertDebug(!recipientUniqueId.isEmpty)
deleteAllSessions(for: recipientUniqueId, tx: tx)
}
}
public func deleteAllSessions(for recipientUniqueId: RecipientUniqueId, tx: DBWriteTransaction) {
keyValueStore.removeValue(forKey: recipientUniqueId, transaction: tx)
}
public func archiveAllSessions(for serviceId: ServiceId, tx: DBWriteTransaction) {
Logger.info("archiving all sessions for \(serviceId)")
switch recipientIdFinder.recipientUniqueId(for: serviceId, tx: tx) {
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
// There can't possibly be any sessions that need to be archived.
return
case .some(.success(let recipientUniqueId)):
archiveAllSessions(for: recipientUniqueId, tx: tx)
}
}
public func archiveAllSessions(for address: SignalServiceAddress, tx: DBWriteTransaction) {
Logger.info("archiving all sessions for \(address)")
switch recipientIdFinder.recipientUniqueId(for: address, tx: tx) {
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
// There can't possibly be any sessions that need to be archived.
return
case .some(.success(let recipientUniqueId)):
archiveAllSessions(for: recipientUniqueId, tx: tx)
}
}
private func archiveAllSessions(for recipientUniqueId: RecipientUniqueId, tx: DBWriteTransaction) {
owsAssertDebug(!recipientUniqueId.isEmpty)
guard let dictionary = loadAllSerializedSessions(for: recipientUniqueId, tx: tx) else {
// We never had a session for this account in the first place.
return
}
let newDictionary: SessionsByDeviceDictionary = dictionary.mapValues { record in
guard let data = serializedSession(fromDatabaseRepresentation: record) else {
// We've already logged an error; skip this session.
return record
}
do {
let session = try SessionRecord(bytes: data)
session.archiveCurrentState()
return session.serialize() as NSData
} catch {
owsFailDebug("\(error)")
return record
}
}
saveSerializedSessions(newDictionary, for: recipientUniqueId, tx: tx)
}
public func resetSessionStore(tx: DBWriteTransaction) {
Logger.warn("resetting session store")
keyValueStore.removeAll(transaction: tx)
}
public func removeAll(tx: DBWriteTransaction) {
keyValueStore.removeAll(transaction: tx)
}
}
extension SSKSessionStore {
public func loadSession(
for serviceId: ServiceId,
deviceId: DeviceId,
tx: DBReadTransaction
) throws -> SessionRecord? {
guard let serializedData = try loadSerializedSession(for: serviceId, deviceId: deviceId.uint32Value, tx: tx) else {
return nil
}
return try SessionRecord(bytes: serializedData)
}
fileprivate func storeSession(
_ record: SessionRecord,
for serviceId: ServiceId,
deviceId: UInt32,
tx: DBWriteTransaction
) throws {
try storeSerializedSession(record.serialize(), for: serviceId, deviceId: deviceId, tx: tx)
}
public func archiveSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBWriteTransaction) {
do {
guard let session = try loadSession(for: serviceId, deviceId: deviceId, tx: tx) else {
return
}
session.archiveCurrentState()
try storeSession(session, for: serviceId, deviceId: deviceId.uint32Value, tx: tx)
} catch {
owsFailDebug("\(error)")
}
}
}
extension SSKSessionStore: LibSignalClient.SessionStore {
public func loadSession(for address: ProtocolAddress, context: StoreContext) throws -> SessionRecord? {
return try loadSession(for: address.serviceId, deviceId: address.deviceIdObj, tx: context.asTransaction)
}
public func loadExistingSessions(
for addresses: [ProtocolAddress],
context: StoreContext
) throws -> [SessionRecord] {
return try addresses.map { address in
guard let session = try loadSession(for: address, context: context) else {
throw SignalError.sessionNotFound("\(address)")
}
return session
}
}
public func storeSession(_ record: SessionRecord, for address: ProtocolAddress, context: StoreContext) throws {
try storeSession(record, for: address.serviceId, deviceId: address.deviceId, tx: context.asTransaction)
}
}
#if TESTABLE_BUILD
extension SSKSessionStore {
// Available through `@testable import`
internal var keyValueStoreForTesting: KeyValueStore {
self.keyValueStore
}
}
#endif

View File

@ -82,7 +82,7 @@ public class SenderKeyStore {
// comparing a set of (deviceId, registrationId) structs, we should be able
// to detect reused deviceIds that will need an SKDM.
let registrationId = try sessionStore.loadSession(
for: serviceId,
forServiceId: serviceId,
deviceId: deviceId,
tx: tx
)?.remoteRegistrationId()

View File

@ -0,0 +1,337 @@
//
// Copyright 2025 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import GRDB
public import LibSignalClient
struct SessionRecord: Codable, FetchableRecord, PersistableRecord {
static let databaseTableName: String = "Session"
let id: Int64
var recipientId: SignalRecipient.RowId
let localIdentity: OWSIdentity
let deviceId: DeviceId
/// May be nil if there was a legacy session.
var serializedRecord: Data?
enum CodingKeys: String, CodingKey {
case id
case recipientId
case localIdentity
case deviceId
case serializedRecord
}
enum Columns {
static let recipientId = Column(CodingKeys.recipientId.rawValue)
static let localIdentity = Column(CodingKeys.localIdentity.rawValue)
static let deviceId = Column(CodingKeys.deviceId.rawValue)
static let serializedRecord = Column(CodingKeys.serializedRecord.rawValue)
}
}
struct SessionStore {
func hasSessionRecords(
forRecipientId recipientId: SignalRecipient.RowId,
localIdentity: OWSIdentity,
tx: DBReadTransaction,
) -> Bool {
let sessionRecords = fetchSessionRecords(
forRecipientId: recipientId,
localIdentity: localIdentity,
tx: tx,
)
return !sessionRecords.isEmpty
}
func mergeRecipientId(
_ recipientId: SignalRecipient.RowId,
into targetRecipientId: SignalRecipient.RowId,
localIdentity: OWSIdentity,
tx: DBWriteTransaction,
) {
if hasSessionRecords(forRecipientId: targetRecipientId, localIdentity: localIdentity, tx: tx) {
// There's already sessions -- prefers those instead of ours.
deleteSessions(forRecipientId: recipientId, localIdentity: localIdentity, tx: tx)
} else {
// There's no sessions -- move ours and reuse them.
let sessionRecords = fetchSessionRecords(
forRecipientId: recipientId,
localIdentity: localIdentity,
tx: tx,
)
for var sessionRecord in sessionRecords {
sessionRecord.recipientId = targetRecipientId
failIfThrows { try sessionRecord.update(tx.database) }
}
}
}
private func buildQuery(
recipientId: SignalRecipient.RowId,
localIdentity: OWSIdentity,
deviceId: DeviceId? = nil,
) -> QueryInterfaceRequest<SessionRecord> {
var result = SessionRecord.filter(SessionRecord.Columns.recipientId == recipientId)
result = result.filter(SessionRecord.Columns.localIdentity == localIdentity.rawValue)
if let deviceId {
result = result.filter(SessionRecord.Columns.deviceId == deviceId.rawValue)
}
return result
}
private func fetchSessionRecords(
forRecipientId recipientId: SignalRecipient.RowId,
localIdentity: OWSIdentity,
deviceId: DeviceId? = nil,
tx: DBReadTransaction,
) -> [SessionRecord] {
return failIfThrows {
return try buildQuery(
recipientId: recipientId,
localIdentity: localIdentity,
deviceId: deviceId,
).fetchAll(tx.database)
}
}
func fetchSession(
forRecipientId recipientId: SignalRecipient.RowId,
localIdentity: OWSIdentity,
deviceId: DeviceId,
tx: DBReadTransaction,
) throws -> LibSignalClient.SessionRecord? {
return try (fetchSessionRecords(
forRecipientId: recipientId,
localIdentity: localIdentity,
deviceId: deviceId,
tx: tx,
).first?.serializedRecord).map(LibSignalClient.SessionRecord.init(bytes:))
}
func archiveSessions(
forRecipientId recipientId: SignalRecipient.RowId,
localIdentity: OWSIdentity,
tx: DBWriteTransaction,
) {
_archiveSessions(forRecipientId: recipientId, localIdentity: localIdentity, deviceId: nil, tx: tx)
}
fileprivate func _archiveSessions(
forRecipientId recipientId: SignalRecipient.RowId,
localIdentity: OWSIdentity,
deviceId: DeviceId?,
tx: DBWriteTransaction,
) {
let sessionRecords = fetchSessionRecords(
forRecipientId: recipientId,
localIdentity: localIdentity,
deviceId: deviceId,
tx: tx,
)
for var sessionRecord in sessionRecords {
guard let serializedRecord = sessionRecord.serializedRecord else {
Logger.warn("couldn't decode legacy session to archive it; leaving it as-is")
continue
}
let libSignalSessionRecord: LibSignalClient.SessionRecord
do {
libSignalSessionRecord = try LibSignalClient.SessionRecord(bytes: serializedRecord)
} catch {
owsFailDebug("couldn't decode session to archive it: \(error)")
continue
}
libSignalSessionRecord.archiveCurrentState()
sessionRecord.serializedRecord = libSignalSessionRecord.serialize()
failIfThrows { try sessionRecord.update(tx.database) }
}
}
func deleteSessions(
forRecipientId recipientId: SignalRecipient.RowId,
localIdentity: OWSIdentity,
tx: DBWriteTransaction,
) {
failIfThrows {
_ = try buildQuery(
recipientId: recipientId,
localIdentity: localIdentity,
deviceId: nil,
).deleteAll(tx.database)
}
}
func upsertSession(
forRecipientId recipientId: SignalRecipient.RowId,
deviceId: DeviceId,
localIdentity: OWSIdentity,
recordData: Data,
tx: DBWriteTransaction,
) {
failIfThrows {
try tx.database.execute(
sql: """
INSERT OR REPLACE INTO \(SessionRecord.databaseTableName) (
\(SessionRecord.Columns.recipientId.name),
\(SessionRecord.Columns.deviceId.name),
\(SessionRecord.Columns.localIdentity.name),
\(SessionRecord.Columns.serializedRecord.name)
) VALUES (?, ?, ?, ?)
""",
arguments: [
recipientId,
deviceId.rawValue,
localIdentity.rawValue,
recordData,
],
)
}
}
func deleteAllSessions(tx: DBWriteTransaction) {
failIfThrows { _ = try SessionRecord.deleteAll(tx.database) }
}
}
public class SessionManagerForIdentity: LibSignalClient.SessionStore {
private let identity: OWSIdentity
private let recipientIdFinder: RecipientIdFinder
private let sessionStore: SessionStore
init(
identity: OWSIdentity,
recipientIdFinder: RecipientIdFinder,
sessionStore: SessionStore,
) {
self.identity = identity
self.recipientIdFinder = recipientIdFinder
self.sessionStore = sessionStore
}
func archiveSession(forServiceId serviceId: ServiceId, deviceId: DeviceId, tx: DBWriteTransaction) {
Logger.info("archiving session for \(serviceId).\(deviceId)")
self._archiveSessions(
recipientIdResult: self.recipientIdFinder.recipientId(for: serviceId, tx: tx),
deviceId: deviceId,
tx: tx,
)
}
public func archiveSessions(forServiceId serviceId: ServiceId, tx: DBWriteTransaction) {
Logger.info("archiving all sessions for \(serviceId)")
self._archiveSessions(
recipientIdResult: self.recipientIdFinder.recipientId(for: serviceId, tx: tx),
deviceId: nil,
tx: tx,
)
}
func archiveSessions(forAddress address: SignalServiceAddress, tx: DBWriteTransaction) {
Logger.info("archiving all sessions for \(address)")
self._archiveSessions(
recipientIdResult: self.recipientIdFinder.recipientId(for: address, tx: tx),
deviceId: nil,
tx: tx,
)
}
private func _archiveSessions(
recipientIdResult: Result<SignalRecipient.RowId, RecipientIdError>?,
deviceId: DeviceId?,
tx: DBWriteTransaction,
) {
switch recipientIdResult {
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
// There can't possibly be any sessions that need to be archived.
return
case .some(.success(let recipientId)):
self.sessionStore._archiveSessions(
forRecipientId: recipientId,
localIdentity: self.identity,
deviceId: deviceId,
tx: tx,
)
}
}
public func deleteSessions(forServiceId serviceId: ServiceId, tx: DBWriteTransaction) {
switch self.recipientIdFinder.recipientId(for: serviceId, tx: tx) {
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
// There can't possibly be any sessions that need to be deleted.
return
case .some(.success(let recipientId)):
self.sessionStore.deleteSessions(forRecipientId: recipientId, localIdentity: self.identity, tx: tx)
}
}
func loadSession(
forServiceId serviceId: ServiceId,
deviceId: DeviceId,
tx: DBReadTransaction,
) throws -> LibSignalClient.SessionRecord? {
switch self.recipientIdFinder.recipientId(for: serviceId, tx: tx) {
case .none:
return nil
case .some(.success(let recipientId)):
return try self.sessionStore.fetchSession(
forRecipientId: recipientId,
localIdentity: self.identity,
deviceId: deviceId,
tx: tx,
)
case .some(.failure(let error)):
switch error {
case .mustNotUsePniBecauseAciExists:
throw error
}
}
}
public func loadSession(
for address: LibSignalClient.ProtocolAddress,
context: any LibSignalClient.StoreContext,
) throws -> LibSignalClient.SessionRecord? {
return try loadSession(
forServiceId: address.serviceId,
deviceId: address.deviceIdObj,
tx: context.asTransaction,
)
}
public func loadExistingSessions(
for addresses: [LibSignalClient.ProtocolAddress],
context: any LibSignalClient.StoreContext,
) throws -> [LibSignalClient.SessionRecord] {
return try addresses.map { address in
guard let session = try loadSession(for: address, context: context) else {
throw SignalError.sessionNotFound("\(address)")
}
return session
}
}
public func storeSession(
_ record: LibSignalClient.SessionRecord,
for address: LibSignalClient.ProtocolAddress,
context: any LibSignalClient.StoreContext,
) throws {
switch recipientIdFinder.ensureRecipientId(for: address.serviceId, tx: context.asTransaction) {
case .success(let recipientId):
self.sessionStore.upsertSession(
forRecipientId: recipientId,
deviceId: address.deviceIdObj,
localIdentity: self.identity,
recordData: record.serialize(),
tx: context.asTransaction,
)
case .failure(let error):
switch error {
case .mustNotUsePniBecauseAciExists:
throw error
}
}
}
}

View File

@ -6,7 +6,7 @@
/// Wraps the stores for 1:1 sessions that use the Signal Protocol (Double Ratchet + X3DH).
public struct SignalProtocolStore {
public let sessionStore: SignalSessionStore
public let sessionStore: SessionManagerForIdentity
public let preKeyStore: PreKeyStoreImpl
public let signedPreKeyStore: SignedPreKeyStoreImpl
public let kyberPreKeyStore: KyberPreKeyStoreImpl
@ -16,9 +16,10 @@ public struct SignalProtocolStore {
identity: OWSIdentity,
preKeyStore: PreKeyStore,
recipientIdFinder: RecipientIdFinder,
sessionStore: SessionStore,
) -> Self {
return Self(
sessionStore: SSKSessionStore(for: identity, recipientIdFinder: recipientIdFinder),
sessionStore: SessionManagerForIdentity(identity: identity, recipientIdFinder: recipientIdFinder, sessionStore: sessionStore),
preKeyStore: PreKeyStoreImpl(for: identity, preKeyStore: preKeyStore),
signedPreKeyStore: SignedPreKeyStoreImpl(for: identity, preKeyStore: preKeyStore),
kyberPreKeyStore: KyberPreKeyStoreImpl(for: identity, dateProvider: dateProvider, preKeyStore: preKeyStore),
@ -31,6 +32,7 @@ public struct SignalProtocolStoreManager {
let aciProtocolStore: SignalProtocolStore
let pniProtocolStore: SignalProtocolStore
let preKeyStore: PreKeyStore
let sessionStore: SessionStore
public func signalProtocolStore(for identity: OWSIdentity) -> SignalProtocolStore {
switch identity {
@ -43,11 +45,11 @@ public struct SignalProtocolStoreManager {
public func removeAllKeys(tx: DBWriteTransaction) {
for signalProtocolStore in [aciProtocolStore, pniProtocolStore] {
signalProtocolStore.sessionStore.removeAll(tx: tx)
signalProtocolStore.preKeyStore.removeMetadata(tx: tx)
signalProtocolStore.signedPreKeyStore.removeMetadata(tx: tx)
signalProtocolStore.kyberPreKeyStore.removeMetadata(tx: tx)
}
self.sessionStore.deleteAllSessions(tx: tx)
self.preKeyStore.removeAll(tx: tx)
}
}

View File

@ -1,62 +0,0 @@
//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
public import LibSignalClient
public protocol SignalSessionStore: LibSignalClient.SessionStore {
func mightContainSession(
for recipient: SignalRecipient,
tx: DBReadTransaction
) -> Bool
func mergeRecipient(
_ recipient: SignalRecipient,
into targetRecipient: SignalRecipient,
tx: DBWriteTransaction
)
func archiveAllSessions(
for serviceId: ServiceId,
tx: DBWriteTransaction
)
/// Deprecated. Prefer the variant that accepts a ServiceId.
func archiveAllSessions(
for address: SignalServiceAddress,
tx: DBWriteTransaction
)
func archiveSession(
for serviceId: ServiceId,
deviceId: DeviceId,
tx: DBWriteTransaction
)
func loadSession(
for serviceId: ServiceId,
deviceId: DeviceId,
tx: DBReadTransaction
) throws -> SessionRecord?
func loadSession(
for address: ProtocolAddress,
context: StoreContext
) throws -> SessionRecord?
func resetSessionStore(tx: DBWriteTransaction)
func deleteAllSessions(
for serviceId: ServiceId,
tx: DBWriteTransaction
)
func deleteAllSessions(
for recipientUniqueId: RecipientUniqueId,
tx: DBWriteTransaction
)
func removeAll(tx: DBWriteTransaction)
}

View File

@ -400,6 +400,7 @@ public enum DatabaseRecovery {
StoryRecipient.databaseTableName,
PreKey.databaseTableName,
KyberPreKeyUseRecord.databaseTableName,
SignalServiceKit.SessionRecord.databaseTableName,
]
/// Copy tables that must be copied flawlessly. Operation throws if any tables fail.

View File

@ -315,6 +315,7 @@ public class GRDBSchemaMigrator {
case addPinnedAtTimestampToPinnedMessageTable
case dropTSAttachment
case deprecateStoredShouldStartExpireTimer
case addSession
// NOTE: Every time we add a migration id, consider
// incrementing grdbSchemaVersionLatest.
@ -437,7 +438,7 @@ public class GRDBSchemaMigrator {
}
public static let grdbSchemaVersionDefault: UInt = 0
public static let grdbSchemaVersionLatest: UInt = 137
public static let grdbSchemaVersionLatest: UInt = 138
private class DatabaseMigratorWrapper {
var migrator = DatabaseMigrator()
@ -4952,6 +4953,15 @@ public class GRDBSchemaMigrator {
return .success(())
}
migrator.registerMigration(.addSession) { tx in
try createSession(tx: tx)
if BuildFlags.migrateDeprecatedSessions {
try migrateSessions(tx: tx)
}
try dropOldSessions(tx: tx)
return .success(())
}
// MARK: - Schema Migration Insertion Point
}
@ -7250,6 +7260,105 @@ public class GRDBSchemaMigrator {
arguments: ["TSStorageUserAccountCollection"],
) ?? false
}
static func createSession(tx: DBWriteTransaction) throws {
try tx.database.create(table: "Session") {
$0.column("id", .integer).primaryKey().notNull()
$0.column("recipientId", .integer).notNull()
.references("model_SignalRecipient", column: "id", onDelete: .cascade, onUpdate: .cascade)
$0.column("localIdentity", .integer).notNull()
$0.column("deviceId", .integer).notNull()
$0.column("serializedRecord", .blob)
$0.check(sql: #"1 <= "deviceId" AND "deviceId" <= 127"#)
}
// For fetching session(s) for a recipient.
try tx.database.create(
index: "Session_Unique",
on: "Session",
columns: ["recipientId", "localIdentity", "deviceId"],
options: [.unique],
)
}
static func migrateSessions(tx: DBWriteTransaction) throws {
// If these ever change, you'll need to add a new migration to update the
// Session table and replace the old constants with the new constants.
assert(OWSIdentity.aci.rawValue == 0)
assert(OWSIdentity.pni.rawValue == 1)
try migrateSessions(in: "TSStorageManagerSessionStoreCollection", identity: 0, tx: tx)
try migrateSessions(in: "TSStorageManagerPNISessionStoreCollection", identity: 1, tx: tx)
}
static func migrateSessions(
in collection: String,
identity: Int64,
tx: DBWriteTransaction,
) throws {
let keys = try String.fetchAll(
tx.database,
sql: "SELECT key FROM keyvalue WHERE collection = ?",
arguments: [collection],
)
for key in keys { try autoreleasepool {
let dataValue = try Data.fetchOne(
tx.database,
sql: "SELECT value FROM keyvalue WHERE collection = ? AND key = ?",
arguments: [collection, key],
)!
let sessionDictionary: [Int32: Data?]
let decodedValue = try? NSKeyedUnarchiver.unarchivedObject(
ofClasses: [NSDictionary.self, NSNumber.self, NSData.self],
from: dataValue,
) as? [Int32: Data]
if let decodedValue {
sessionDictionary = decodedValue
} else {
// We expect some failures (for legacy data), and if there are failures, we
// want to remember that there was a session, even though we can't do
// anything with that session. (See also `hasSessionRecords`).
Logger.warn("Storing nil for \(key) in \(collection) that couldn't be decoded")
sessionDictionary = [1: nil]
}
let recipientId = try Int64.fetchOne(
tx.database,
sql: "SELECT id FROM model_SignalRecipient WHERE uniqueId = ?",
arguments: [key],
)
guard let recipientId else {
// If we can't find the SignalRecipient, these sessions aren't reachable,
// so we don't need to keep them. (Foreign key constraints will enforce
// this moving forward.)
Logger.warn("Skipping \(key) in \(collection) that's been orphaned")
return
}
for (deviceId, serializedRecord) in sessionDictionary {
guard deviceId >= 1 && deviceId <= 127 else {
Logger.warn("Skipping \(deviceId) for \(key) in \(collection) that's not valid")
continue
}
try tx.database.execute(
sql: "INSERT INTO Session (recipientId, localIdentity, deviceId, serializedRecord) VALUES (?, ?, ?, ?)",
arguments: [
recipientId,
identity,
deviceId,
serializedRecord,
],
)
}
}}
}
static func dropOldSessions(tx: DBWriteTransaction) throws {
let collections = [
"TSStorageManagerSessionStoreCollection",
"TSStorageManagerPNISessionStoreCollection",
]
for collection in collections {
try tx.database.execute(sql: "DELETE FROM keyvalue WHERE collection = ?", arguments: [collection])
}
}
}
// MARK: -

View File

@ -34,29 +34,50 @@ public final class RecipientIdFinder {
guard let recipient = recipientDatabaseTable.fetchRecipient(serviceId: serviceId, transaction: tx) else {
return nil
}
return recipientUniqueIdResult(for: serviceId, recipient: recipient)
return validateRecipient(recipient, for: serviceId).map(\.uniqueId)
}
public func recipientUniqueId(for address: SignalServiceAddress, tx: DBReadTransaction) -> Result<RecipientUniqueId, RecipientIdError>? {
guard
let recipient = DependenciesBridge.shared.recipientDatabaseTable
.fetchRecipient(address: address, tx: tx)
else {
guard let recipient = recipientDatabaseTable.fetchRecipient(address: address, tx: tx) else {
return nil
}
return recipientUniqueIdResult(for: address.serviceId, recipient: recipient)
return validateRecipient(recipient, for: address.serviceId).map(\.uniqueId)
}
public func ensureRecipientUniqueId(for serviceId: ServiceId, tx: DBWriteTransaction) -> Result<RecipientUniqueId, RecipientIdError> {
let recipient = recipientFetcher.fetchOrCreate(serviceId: serviceId, tx: tx)
return recipientUniqueIdResult(for: serviceId, recipient: recipient)
return ensureRecipient(for: serviceId, tx: tx).map(\.uniqueId)
}
private func recipientUniqueIdResult(for serviceId: ServiceId?, recipient: SignalRecipient) -> Result<RecipientUniqueId, RecipientIdError> {
public func recipientId(for serviceId: ServiceId, tx: DBReadTransaction) -> Result<SignalRecipient.RowId, RecipientIdError>? {
guard let recipient = recipientDatabaseTable.fetchRecipient(serviceId: serviceId, transaction: tx) else {
return nil
}
return validateRecipient(recipient, for: serviceId).map(\.id)
}
public func recipientId(for address: SignalServiceAddress, tx: DBReadTransaction) -> Result<SignalRecipient.RowId, RecipientIdError>? {
guard let recipient = recipientDatabaseTable.fetchRecipient(address: address, tx: tx) else {
return nil
}
return validateRecipient(recipient, for: address.serviceId).map(\.id)
}
public func ensureRecipientId(for serviceId: ServiceId, tx: DBWriteTransaction) -> Result<SignalRecipient.RowId, RecipientIdError> {
return ensureRecipient(for: serviceId, tx: tx).map(\.id)
}
public func ensureRecipient(for serviceId: ServiceId, tx: DBWriteTransaction) -> Result<SignalRecipient, RecipientIdError> {
let recipient = recipientFetcher.fetchOrCreate(serviceId: serviceId, tx: tx)
return validateRecipient(recipient, for: serviceId)
}
private func validateRecipient(
_ recipient: SignalRecipient,
for serviceId: ServiceId?,
) -> Result<SignalRecipient, RecipientIdError> {
if serviceId is Pni, recipient.aciString != nil {
return .failure(.mustNotUsePniBecauseAciExists)
}
return .success(recipient.uniqueId)
return .success(recipient)
}
}

View File

@ -266,12 +266,16 @@ struct FakeSignalClient: TestSignalClient {
struct LocalSignalClient: TestSignalClient {
let identity: OWSIdentity
let _preKeyStore: PreKeyStore
let _sessionStore: SSKSessionStore
let _sessionStore: SessionManagerForIdentity
init(identity: OWSIdentity = .aci) {
self.identity = identity
self._preKeyStore = PreKeyStore()
self._sessionStore = SSKSessionStore(for: identity, recipientIdFinder: DependenciesBridge.shared.recipientIdFinder)
self._sessionStore = SessionManagerForIdentity(
identity: identity,
recipientIdFinder: DependenciesBridge.shared.recipientIdFinder,
sessionStore: SessionStore(),
)
}
var identityKeyPair: ECKeyPair {

View File

@ -45,14 +45,28 @@ final class PreKeyTaskTests: SSKBaseTest {
mockAPIClient = .init()
mockDateProvider = .init()
mockDb = InMemoryDB()
let sessionStore = SignalServiceKit.SessionStore()
mockPreKeyStore = PreKeyStore()
mockAciProtocolStore = .mock(identity: .aci, preKeyStore: mockPreKeyStore)
mockPniProtocolStore = .mock(identity: .pni, preKeyStore: mockPreKeyStore)
mockAciProtocolStore = .build(
dateProvider: mockDateProvider.targetDate,
identity: .aci,
preKeyStore: mockPreKeyStore,
recipientIdFinder: recipientIdFinder,
sessionStore: sessionStore,
)
mockPniProtocolStore = .build(
dateProvider: mockDateProvider.targetDate,
identity: .pni,
preKeyStore: mockPreKeyStore,
recipientIdFinder: recipientIdFinder,
sessionStore: sessionStore,
)
mockProtocolStoreManager = SignalProtocolStoreManager(
aciProtocolStore: mockAciProtocolStore,
pniProtocolStore: mockPniProtocolStore,
preKeyStore: mockPreKeyStore,
sessionStore: sessionStore,
)
taskManager = PreKeyTaskManager(

View File

@ -28,16 +28,13 @@ private class MockStorageServiceManager: StorageServiceManager {
}
private class TestDependencies {
let aciSessionStore: SignalSessionStore
var aciSessionStoreKeyValueStore: KeyValueStore {
KeyValueStore(collection: "TSStorageManagerSessionStoreCollection")
}
let identityManager: MockIdentityManager
let mockDB = InMemoryDB()
let recipientMerger: RecipientMerger
let recipientDatabaseTable = RecipientDatabaseTable()
let recipientFetcher: RecipientFetcher
let recipientIdFinder: RecipientIdFinder
let sessionStore: SignalServiceKit.SessionStore
let threadAssociatedDataStore: MockThreadAssociatedDataStore
let threadStore: MockThreadStore
let threadMerger: ThreadMerger
@ -50,10 +47,10 @@ private class TestDependencies {
searchableNameIndexer: searchableNameIndexer,
)
recipientIdFinder = RecipientIdFinder(recipientDatabaseTable: recipientDatabaseTable, recipientFetcher: recipientFetcher)
aciSessionStore = SSKSessionStore(for: .aci, recipientIdFinder: recipientIdFinder)
identityManager = MockIdentityManager(recipientIdFinder: recipientIdFinder)
identityManager.recipientIdentities = [:]
identityManager.sessionSwitchoverMessages = []
sessionStore = SignalServiceKit.SessionStore()
threadAssociatedDataStore = MockThreadAssociatedDataStore()
threadStore = MockThreadStore()
threadMerger = ThreadMerger.forUnitTests(
@ -61,7 +58,6 @@ private class TestDependencies {
threadStore: threadStore
)
recipientMerger = RecipientMergerImpl(
aciSessionStore: aciSessionStore,
blockedRecipientStore: BlockedRecipientStore(),
identityManager: identityManager,
observers: RecipientMergerImpl.Observers(
@ -72,6 +68,7 @@ private class TestDependencies {
recipientDatabaseTable: recipientDatabaseTable,
recipientFetcher: recipientFetcher,
searchableNameIndexer: searchableNameIndexer,
sessionStore: sessionStore,
storageServiceManager: storageServiceManager,
storyRecipientStore: StoryRecipientStore()
)
@ -253,7 +250,13 @@ class RecipientMergerTest: XCTestCase {
createdAt: Date(),
verificationState: .default
)
d.aciSessionStoreKeyValueStore.setData(Data(), key: recipient.uniqueId, transaction: tx)
d.sessionStore.upsertSession(
forRecipientId: recipient.id,
deviceId: .primary,
localIdentity: .aci,
recordData: Data(),
tx: tx,
)
}
}
@ -266,7 +269,7 @@ class RecipientMergerTest: XCTestCase {
XCTAssertEqual(d.identityManager.identityChangeInfoMessages, testCase.shouldInsertEvent ? [ac1] : [])
XCTAssertEqual(try! d.identityManager.identityKey(for: ac1, tx: tx), ik1)
XCTAssertTrue(d.aciSessionStore.mightContainSession(for: mergedRecipient, tx: tx))
XCTAssertTrue(d.sessionStore.hasSessionRecords(forRecipientId: mergedRecipient.id, localIdentity: .aci, tx: tx))
}
}
}
@ -402,7 +405,13 @@ class RecipientMergerTest: XCTestCase {
d.mockDB.write { tx in
for recipientNumber in testCase.hasSession {
let recipient = recipients.dropFirst(recipientNumber - 1).first!
d.aciSessionStoreKeyValueStore.setData(Data(), key: recipient.uniqueId, transaction: tx)
d.sessionStore.upsertSession(
forRecipientId: recipient.id,
deviceId: .primary,
localIdentity: .aci,
recordData: Data(),
tx: tx,
)
let thread = TSContactThread(contactAddress: SignalServiceAddress(
serviceId: recipient.aci ?? recipient.pni,
phoneNumber: recipient.phoneNumber?.stringValue,
@ -444,10 +453,22 @@ class RecipientMergerTest: XCTestCase {
let d = TestDependencies()
let aciRecipient = d.mockDB.write { tx in
let aciRecipient = try! SignalRecipient.insertRecord(aci: aci, tx: tx)
d.aciSessionStoreKeyValueStore.setData(Data(), key: aciRecipient.uniqueId, transaction: tx)
d.sessionStore.upsertSession(
forRecipientId: aciRecipient.id,
deviceId: .primary,
localIdentity: .aci,
recordData: Data(),
tx: tx,
)
let pniRecipient = try! SignalRecipient.insertRecord(phoneNumber: phoneNumber, pni: pni, tx: tx)
d.aciSessionStoreKeyValueStore.setData(Data(), key: pniRecipient.uniqueId, transaction: tx)
d.sessionStore.upsertSession(
forRecipientId: pniRecipient.id,
deviceId: .primary,
localIdentity: .aci,
recordData: Data(),
tx: tx,
)
return aciRecipient
}

View File

@ -8,58 +8,6 @@ import LibSignalClient
@testable import SignalServiceKit
class SessionStoreTest: SSKBaseTest {
func testLegacySessionIsDropped() {
@objc(FakeLegacySession) class FakeLegacySession: NSObject, NSCoding {
override init() {
}
required init?(coder: NSCoder) {
fatalError("should never be deserialized")
}
func encode(with coder: NSCoder) {
// no properties
}
}
// We have to use the database-based KeyValueStore to test this
// because the in-memory one skips the archiving step.
let sessionStore = SSKSessionStore(for: .aci, recipientIdFinder: DependenciesBridge.shared.recipientIdFinder)
let recipient = write {
DependenciesBridge.shared.recipientFetcher.fetchOrCreate(serviceId: Aci.randomForTesting(), tx: $0)
}
// First make sure that if we write a "valid" session, it can be read.
let singleValidSessionData = try! NSKeyedArchiver.archivedData(withRootObject: [1: Data()], requiringSecureCoding: true)
write {
sessionStore.keyValueStoreForTesting.setData(singleValidSessionData, key: recipient.uniqueId, transaction: $0)
}
read {
XCTAssertTrue(sessionStore.mightContainSession(for: recipient, tx: $0))
XCTAssertNotNil(try! sessionStore.loadSession(for: recipient.aci!, deviceId: DeviceId(validating: 1)!, tx: $0))
}
// Then imitate a session store with a mix of legacy and modern sessions.
let sessions: NSDictionary = [1: FakeLegacySession(), 2: Data()]
let archiver = NSKeyedArchiver(requiringSecureCoding: false)
archiver.setClassName("SSKLegacySessionClassThatNoLongerExists", for: FakeLegacySession.self)
archiver.encode(sessions, forKey: NSKeyedArchiveRootObjectKey)
write {
sessionStore.keyValueStoreForTesting.setData(archiver.encodedData, key: recipient.uniqueId, transaction: $0)
}
read {
// There's something in the store...
XCTAssertTrue(sessionStore.mightContainSession(for: recipient, tx: $0))
// ...but it turns into nil on load.
XCTAssertNil(try! sessionStore.loadSession(for: recipient.aci!, deviceId: DeviceId(validating: 2)!, tx: $0))
}
}
}
class SessionStoreTest2: XCTestCase {
func testMaxUnacknowledgedSessionAge() throws {
let bob_address = try ProtocolAddress(name: "+14155550100", deviceId: 1)

View File

@ -1160,4 +1160,118 @@ class GRDBSchemaMigratorTest: XCTestCase {
XCTAssertEqual(callLinks[2][0] as String?, "Something")
}
}
private func keyedArchiverSessionData(deviceIds: [Int32]) -> Data {
let sessionDictionary = Dictionary(uniqueKeysWithValues: deviceIds.map { ($0, Data()) })
return keyedArchiverData(rootObject: sessionDictionary)
}
func testMigrateSessions() throws {
let databaseQueue = DatabaseQueue()
try databaseQueue.write { db in
try db.execute(sql: """
CREATE TABLE keyvalue (collection TEXT NOT NULL, key TEXT NOT NULL, value BLOB NOT NULL);
""")
try db.execute(sql: """
CREATE TABLE model_SignalRecipient (id INTEGER PRIMARY KEY, uniqueId TEXT NOT NULL);
""")
let recipient1UniqueId = UUID().uuidString
let recipient2UniqueId = UUID().uuidString
let recipient3UniqueId = UUID().uuidString
let recipient4UniqueId = UUID().uuidString
try db.execute(
sql: "INSERT INTO model_SignalRecipient (id, uniqueId) VALUES (?, ?)",
arguments: [1, recipient1UniqueId],
)
try db.execute(
sql: "INSERT INTO model_SignalRecipient (id, uniqueId) VALUES (?, ?)",
arguments: [2, recipient2UniqueId],
)
// Don't insert recipient3UniqueKey.
try db.execute(
sql: "INSERT INTO model_SignalRecipient (id, uniqueId) VALUES (?, ?)",
arguments: [4, recipient4UniqueId],
)
try db.execute(
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
arguments: ["TSStorageManagerSessionStoreCollection", recipient1UniqueId, keyedArchiverSessionData(deviceIds: [1, 2, 128])],
)
try db.execute(
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
arguments: ["TSStorageManagerPNISessionStoreCollection", recipient1UniqueId, keyedArchiverSessionData(deviceIds: [0, 2, 3])],
)
try db.execute(
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
arguments: ["TSStorageManagerSessionStoreCollection", recipient2UniqueId, keyedArchiverSessionData(deviceIds: [1])],
)
try db.execute(
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
arguments: ["TSStorageManagerSessionStoreCollection", recipient3UniqueId, keyedArchiverSessionData(deviceIds: [1])],
)
@objc(FakeLegacySession) class FakeLegacySession: NSObject, NSCoding {
override init() {}
required init?(coder: NSCoder) { fatalError("should never be deserialized") }
func encode(with coder: NSCoder) {}
}
let legacyArchivedData: Data
do {
let sessionDictionary: [Int32: AnyObject] = [1: FakeLegacySession(), 2: Data() as NSData]
let archiver = NSKeyedArchiver(requiringSecureCoding: false)
archiver.setClassName("SSKLegacySessionClassThatNoLongerExists", for: FakeLegacySession.self)
archiver.encode(sessionDictionary, forKey: NSKeyedArchiveRootObjectKey)
legacyArchivedData = archiver.encodedData
}
try db.execute(
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
arguments: ["TSStorageManagerSessionStoreCollection", recipient4UniqueId, legacyArchivedData],
)
do {
let tx = DBWriteTransaction(database: db)
defer { tx.finalizeTransaction() }
try GRDBSchemaMigrator.createSession(tx: tx)
try GRDBSchemaMigrator.migrateSessions(tx: tx)
try GRDBSchemaMigrator.dropOldSessions(tx: tx)
}
let sessions = try Row.fetchAll(db, sql: "SELECT * FROM Session ORDER BY recipientId, localIdentity, deviceId")
XCTAssertEqual(sessions.count, 6)
XCTAssertEqual(sessions[0]["recipientId"] as Int64, 1)
XCTAssertEqual(sessions[0]["localIdentity"] as Int64, 0)
XCTAssertEqual(sessions[0]["deviceId"] as Int8, 1)
XCTAssertEqual(sessions[0]["serializedRecord"] as Data?, Data())
XCTAssertEqual(sessions[1]["recipientId"] as Int64, 1)
XCTAssertEqual(sessions[1]["localIdentity"] as Int64, 0)
XCTAssertEqual(sessions[1]["deviceId"] as Int8, 2)
XCTAssertEqual(sessions[1]["serializedRecord"] as Data?, Data())
XCTAssertEqual(sessions[2]["recipientId"] as Int64, 1)
XCTAssertEqual(sessions[2]["localIdentity"] as Int64, 1)
XCTAssertEqual(sessions[2]["deviceId"] as Int8, 2)
XCTAssertEqual(sessions[2]["serializedRecord"] as Data?, Data())
XCTAssertEqual(sessions[3]["recipientId"] as Int64, 1)
XCTAssertEqual(sessions[3]["localIdentity"] as Int64, 1)
XCTAssertEqual(sessions[3]["deviceId"] as Int8, 3)
XCTAssertEqual(sessions[3]["serializedRecord"] as Data?, Data())
XCTAssertEqual(sessions[4]["recipientId"] as Int64, 2)
XCTAssertEqual(sessions[4]["localIdentity"] as Int64, 0)
XCTAssertEqual(sessions[4]["deviceId"] as Int8, 1)
XCTAssertEqual(sessions[4]["serializedRecord"] as Data?, Data())
XCTAssertEqual(sessions[5]["recipientId"] as Int64, 4)
XCTAssertEqual(sessions[5]["localIdentity"] as Int64, 0)
XCTAssertEqual(sessions[5]["deviceId"] as Int8, 1)
XCTAssertEqual(sessions[5]["serializedRecord"] as Data?, nil)
}
}
}