Migrate sessions to a dedicated table
This commit is contained in:
parent
9d37668550
commit
4ae6bfe50d
@ -572,6 +572,7 @@
|
||||
500FE4E2288A373100FA090C /* BadgeGiftingAlreadyRedeemedSheet.swift in Sources */ = {isa = PBXBuildFile; fileRef = 500FE4E1288A373100FA090C /* BadgeGiftingAlreadyRedeemedSheet.swift */; };
|
||||
50101FB22B083C8100C648E4 /* ChatListSettingsButtonState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50101FB12B083C8100C648E4 /* ChatListSettingsButtonState.swift */; };
|
||||
50101FB42B08447000C648E4 /* ChatListProxyButtonCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50101FB32B08447000C648E4 /* ChatListProxyButtonCreator.swift */; };
|
||||
501050BB2EB959A4005161CA /* SessionStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = 501050BA2EB959A4005161CA /* SessionStore.swift */; };
|
||||
501052642BDAEEDC0097DDC5 /* MobileCoinExternal.pb.swift in Sources */ = {isa = PBXBuildFile; fileRef = 501052632BDAEEDC0097DDC5 /* MobileCoinExternal.pb.swift */; };
|
||||
501052672BDB22940097DDC5 /* PrivacyInfo.xcprivacy in Resources */ = {isa = PBXBuildFile; fileRef = 501052652BDB15B90097DDC5 /* PrivacyInfo.xcprivacy */; };
|
||||
501052692BDB232A0097DDC5 /* PrivacyInfo.xcprivacy in Resources */ = {isa = PBXBuildFile; fileRef = 501052682BDB232A0097DDC5 /* PrivacyInfo.xcprivacy */; };
|
||||
@ -1806,7 +1807,6 @@
|
||||
C1C4AA3329E7038D000CE9D3 /* EditManagerShims.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1C4AA3229E7038D000CE9D3 /* EditManagerShims.swift */; };
|
||||
C1C7E4FB2BE0419300F196EE /* UploadMetadata.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1C7E4FA2BE0419300F196EE /* UploadMetadata.swift */; };
|
||||
C1CA5F8E2BE2F21C00D733CA /* BackupArchiveDistributionListRecipientArchiver.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CA5F8D2BE2F21C00D733CA /* BackupArchiveDistributionListRecipientArchiver.swift */; };
|
||||
C1CD0E3A2A6B0D2700307F1A /* SignalSessionStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CD0E392A6B0D2700307F1A /* SignalSessionStore.swift */; };
|
||||
C1CD0E402A6B37BF00307F1A /* PreKeyStoreImplTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CD0E3F2A6B37BF00307F1A /* PreKeyStoreImplTest.swift */; };
|
||||
C1CF83D02B96C85E00CDC9C4 /* ChunkedOutputStreamTransform.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CF83CF2B96C85E00CDC9C4 /* ChunkedOutputStreamTransform.swift */; };
|
||||
C1CF83D22B9A1FCB00CDC9C4 /* GzipStreamTransform.swift in Sources */ = {isa = PBXBuildFile; fileRef = C1CF83D12B9A1FCB00CDC9C4 /* GzipStreamTransform.swift */; };
|
||||
@ -3739,7 +3739,6 @@
|
||||
F9C5CD33289453B300548EEE /* SignedPreKeyStoreImpl.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA55289453B100548EEE /* SignedPreKeyStoreImpl.swift */; };
|
||||
F9C5CD34289453B300548EEE /* SignalProtocolStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA56289453B100548EEE /* SignalProtocolStore.swift */; };
|
||||
F9C5CD37289453B300548EEE /* SenderKeyStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA59289453B100548EEE /* SenderKeyStore.swift */; };
|
||||
F9C5CD3C289453B300548EEE /* SSKSessionStore.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA5E289453B100548EEE /* SSKSessionStore.swift */; };
|
||||
F9C5CD52289453B300548EEE /* PreKeyStoreImpl.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA75289453B100548EEE /* PreKeyStoreImpl.swift */; };
|
||||
F9C5CD54289453B300548EEE /* SSKKeychainStorage.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA77289453B100548EEE /* SSKKeychainStorage.swift */; };
|
||||
F9C5CD56289453B300548EEE /* PendingViewedReceiptRecord.swift in Sources */ = {isa = PBXBuildFile; fileRef = F9C5CA79289453B100548EEE /* PendingViewedReceiptRecord.swift */; };
|
||||
@ -4716,6 +4715,7 @@
|
||||
500FE4E1288A373100FA090C /* BadgeGiftingAlreadyRedeemedSheet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BadgeGiftingAlreadyRedeemedSheet.swift; sourceTree = "<group>"; };
|
||||
50101FB12B083C8100C648E4 /* ChatListSettingsButtonState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatListSettingsButtonState.swift; sourceTree = "<group>"; };
|
||||
50101FB32B08447000C648E4 /* ChatListProxyButtonCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatListProxyButtonCreator.swift; sourceTree = "<group>"; };
|
||||
501050BA2EB959A4005161CA /* SessionStore.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SessionStore.swift; sourceTree = "<group>"; };
|
||||
501052632BDAEEDC0097DDC5 /* MobileCoinExternal.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MobileCoinExternal.pb.swift; sourceTree = "<group>"; };
|
||||
501052652BDB15B90097DDC5 /* PrivacyInfo.xcprivacy */ = {isa = PBXFileReference; lastKnownFileType = text.xml; path = PrivacyInfo.xcprivacy; sourceTree = "<group>"; };
|
||||
501052682BDB232A0097DDC5 /* PrivacyInfo.xcprivacy */ = {isa = PBXFileReference; lastKnownFileType = text.xml; path = PrivacyInfo.xcprivacy; sourceTree = "<group>"; };
|
||||
@ -5963,7 +5963,6 @@
|
||||
C1C4AA3229E7038D000CE9D3 /* EditManagerShims.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = EditManagerShims.swift; sourceTree = "<group>"; };
|
||||
C1C7E4FA2BE0419300F196EE /* UploadMetadata.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UploadMetadata.swift; sourceTree = "<group>"; };
|
||||
C1CA5F8D2BE2F21C00D733CA /* BackupArchiveDistributionListRecipientArchiver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BackupArchiveDistributionListRecipientArchiver.swift; sourceTree = "<group>"; };
|
||||
C1CD0E392A6B0D2700307F1A /* SignalSessionStore.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SignalSessionStore.swift; sourceTree = "<group>"; };
|
||||
C1CD0E3F2A6B37BF00307F1A /* PreKeyStoreImplTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreKeyStoreImplTest.swift; sourceTree = "<group>"; };
|
||||
C1CF83CF2B96C85E00CDC9C4 /* ChunkedOutputStreamTransform.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChunkedOutputStreamTransform.swift; sourceTree = "<group>"; };
|
||||
C1CF83D12B9A1FCB00CDC9C4 /* GzipStreamTransform.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GzipStreamTransform.swift; sourceTree = "<group>"; };
|
||||
@ -7934,7 +7933,6 @@
|
||||
F9C5CA55289453B100548EEE /* SignedPreKeyStoreImpl.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SignedPreKeyStoreImpl.swift; sourceTree = "<group>"; };
|
||||
F9C5CA56289453B100548EEE /* SignalProtocolStore.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SignalProtocolStore.swift; sourceTree = "<group>"; };
|
||||
F9C5CA59289453B100548EEE /* SenderKeyStore.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SenderKeyStore.swift; sourceTree = "<group>"; };
|
||||
F9C5CA5E289453B100548EEE /* SSKSessionStore.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SSKSessionStore.swift; sourceTree = "<group>"; };
|
||||
F9C5CA75289453B100548EEE /* PreKeyStoreImpl.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PreKeyStoreImpl.swift; sourceTree = "<group>"; };
|
||||
F9C5CA77289453B100548EEE /* SSKKeychainStorage.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SSKKeychainStorage.swift; sourceTree = "<group>"; };
|
||||
F9C5CA79289453B100548EEE /* PendingViewedReceiptRecord.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PendingViewedReceiptRecord.swift; sourceTree = "<group>"; };
|
||||
@ -14881,10 +14879,9 @@
|
||||
50589CDD2E8C44D5003EF42A /* PreKeyStore.swift */,
|
||||
F9C5CA75289453B100548EEE /* PreKeyStoreImpl.swift */,
|
||||
F9C5CA59289453B100548EEE /* SenderKeyStore.swift */,
|
||||
501050BA2EB959A4005161CA /* SessionStore.swift */,
|
||||
F9C5CA56289453B100548EEE /* SignalProtocolStore.swift */,
|
||||
C1CD0E392A6B0D2700307F1A /* SignalSessionStore.swift */,
|
||||
F9C5CA55289453B100548EEE /* SignedPreKeyStoreImpl.swift */,
|
||||
F9C5CA5E289453B100548EEE /* SSKSessionStore.swift */,
|
||||
);
|
||||
path = AxolotlStore;
|
||||
sourceTree = "<group>";
|
||||
@ -19328,6 +19325,7 @@
|
||||
F9C5CC96289453B300548EEE /* SessionRecord.pb.swift in Sources */,
|
||||
725465242BA017D500EABFD2 /* SessionResetJob.swift in Sources */,
|
||||
D9AE0AD729187A700063488B /* SessionResetJobRecord.swift in Sources */,
|
||||
501050BB2EB959A4005161CA /* SessionStore.swift in Sources */,
|
||||
1700E34128BD41150073D949 /* SetAlgebra+SSK.swift in Sources */,
|
||||
502346772DB039320029DB97 /* SetDeque.swift in Sources */,
|
||||
66C2B14D2A13E2C7008DDE72 /* SgxWebsocketConfigurator.swift in Sources */,
|
||||
@ -19352,7 +19350,6 @@
|
||||
F9C5CC9A289453B300548EEE /* SignalService.pb.swift in Sources */,
|
||||
F9C5CCE2289453B300548EEE /* SignalServiceAddress.swift in Sources */,
|
||||
F9C5CDBB289453B400548EEE /* SignalServiceProfile.swift in Sources */,
|
||||
C1CD0E3A2A6B0D2700307F1A /* SignalSessionStore.swift in Sources */,
|
||||
72B0C2422C9EED0E00B57DAD /* SignedPreKeyRecord.swift in Sources */,
|
||||
F9C5CD33289453B300548EEE /* SignedPreKeyStoreImpl.swift in Sources */,
|
||||
F9C5CC51289453B300548EEE /* SMKError.swift in Sources */,
|
||||
@ -19371,7 +19368,6 @@
|
||||
F9C5CCA4289453B300548EEE /* SSKProto+OWS.swift in Sources */,
|
||||
F9C5CCA1289453B300548EEE /* SSKProto.swift in Sources */,
|
||||
F9C5CC8E289453B300548EEE /* SSKProtos.swift in Sources */,
|
||||
F9C5CD3C289453B300548EEE /* SSKSessionStore.swift in Sources */,
|
||||
F9C5CD9E289453B400548EEE /* SSKWebSocket.swift in Sources */,
|
||||
F9C5CC1B289453B300548EEE /* StickerError.swift in Sources */,
|
||||
F9C5CC13289453B300548EEE /* StickerInfo.m in Sources */,
|
||||
|
||||
@ -23,13 +23,13 @@ class DebugUISessionState: DebugUIPage {
|
||||
OWSTableItem(title: "Delete All Sessions", actionBlock: {
|
||||
SSKEnvironment.shared.databaseStorageRef.write { transaction in
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
sessionStore.deleteAllSessions(for: contactThread.contactAddress.serviceId!, tx: transaction)
|
||||
sessionStore.deleteSessions(forServiceId: contactThread.contactAddress.serviceId!, tx: transaction)
|
||||
}
|
||||
}),
|
||||
OWSTableItem(title: "Archive All Sessions", actionBlock: {
|
||||
SSKEnvironment.shared.databaseStorageRef.write { transaction in
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
sessionStore.archiveAllSessions(for: contactThread.contactAddress.serviceId!, tx: transaction)
|
||||
sessionStore.archiveSessions(forServiceId: contactThread.contactAddress.serviceId!, tx: transaction)
|
||||
}
|
||||
}),
|
||||
]
|
||||
|
||||
@ -135,8 +135,8 @@ public class ConversationInternalViewController: OWSTableViewController2 {
|
||||
let sessionSection = OWSTableSection()
|
||||
sessionSection.add(.actionItem(withText: "Delete Session") {
|
||||
SSKEnvironment.shared.databaseStorageRef.write { transaction in
|
||||
let aciStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci)
|
||||
aciStore.sessionStore.deleteAllSessions(for: address.serviceId!, tx: transaction)
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
sessionStore.deleteSessions(forServiceId: address.serviceId!, tx: transaction)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@ -66,6 +66,7 @@ public class ProvisioningCoordinatorTest: XCTestCase {
|
||||
self.tsAccountManagerMock = .init()
|
||||
self.udManagerMock = .init()
|
||||
let preKeyStore = PreKeyStore()
|
||||
let sessionStore = SignalServiceKit.SessionStore()
|
||||
|
||||
self.provisioningCoordinator = ProvisioningCoordinatorImpl(
|
||||
chatConnectionManager: chatConnectionManagerMock,
|
||||
@ -81,9 +82,10 @@ public class ProvisioningCoordinatorTest: XCTestCase {
|
||||
registrationStateChangeManager: registrationStateChangeManagerMock,
|
||||
registrationWebSocketManager: MockRegistrationWebSocketManager(),
|
||||
signalProtocolStoreManager: SignalProtocolStoreManager(
|
||||
aciProtocolStore: .mock(identity: .aci, preKeyStore: preKeyStore),
|
||||
pniProtocolStore: .mock(identity: .pni, preKeyStore: preKeyStore),
|
||||
aciProtocolStore: .mock(identity: .aci, preKeyStore: preKeyStore, recipientIdFinder: recipientIdFinder, sessionStore: sessionStore),
|
||||
pniProtocolStore: .mock(identity: .pni, preKeyStore: preKeyStore, recipientIdFinder: recipientIdFinder, sessionStore: sessionStore),
|
||||
preKeyStore: preKeyStore,
|
||||
sessionStore: sessionStore,
|
||||
),
|
||||
signalService: signalServiceMock,
|
||||
storageServiceManager: storageServiceManagerMock,
|
||||
|
||||
@ -211,8 +211,7 @@ public class RegistrationStateChangeManagerImpl: RegistrationStateChangeManager
|
||||
tx: tx
|
||||
)
|
||||
|
||||
signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore.resetSessionStore(tx: tx)
|
||||
signalProtocolStoreManager.signalProtocolStore(for: .pni).sessionStore.resetSessionStore(tx: tx)
|
||||
signalProtocolStoreManager.sessionStore.deleteAllSessions(tx: tx)
|
||||
senderKeyStore.resetSenderKeyStore(transaction: tx)
|
||||
udManager.removeSenderCertificates(tx: tx)
|
||||
versionedProfiles.clearProfileKeyCredentials(tx: tx)
|
||||
|
||||
@ -94,13 +94,13 @@ struct MergedRecipient {
|
||||
}
|
||||
|
||||
class RecipientMergerImpl: RecipientMerger {
|
||||
private let aciSessionStore: SignalSessionStore
|
||||
private let blockedRecipientStore: BlockedRecipientStore
|
||||
private let identityManager: OWSIdentityManager
|
||||
private let observers: Observers
|
||||
private let recipientDatabaseTable: RecipientDatabaseTable
|
||||
private let recipientFetcher: RecipientFetcher
|
||||
private let searchableNameIndexer: any SearchableNameIndexer
|
||||
private let sessionStore: SessionStore
|
||||
private let storageServiceManager: StorageServiceManager
|
||||
private let storyRecipientStore: StoryRecipientStore
|
||||
|
||||
@ -111,23 +111,23 @@ class RecipientMergerImpl: RecipientMerger {
|
||||
/// which we learned about the new association, and they are notified in the
|
||||
/// order in which they are provided.
|
||||
init(
|
||||
aciSessionStore: SignalSessionStore,
|
||||
blockedRecipientStore: BlockedRecipientStore,
|
||||
identityManager: OWSIdentityManager,
|
||||
observers: Observers,
|
||||
recipientDatabaseTable: RecipientDatabaseTable,
|
||||
recipientFetcher: RecipientFetcher,
|
||||
searchableNameIndexer: any SearchableNameIndexer,
|
||||
sessionStore: SessionStore,
|
||||
storageServiceManager: StorageServiceManager,
|
||||
storyRecipientStore: StoryRecipientStore
|
||||
) {
|
||||
self.aciSessionStore = aciSessionStore
|
||||
self.blockedRecipientStore = blockedRecipientStore
|
||||
self.identityManager = identityManager
|
||||
self.observers = observers
|
||||
self.recipientDatabaseTable = recipientDatabaseTable
|
||||
self.recipientFetcher = recipientFetcher
|
||||
self.searchableNameIndexer = searchableNameIndexer
|
||||
self.sessionStore = sessionStore
|
||||
self.storageServiceManager = storageServiceManager
|
||||
self.storyRecipientStore = storyRecipientStore
|
||||
}
|
||||
@ -801,7 +801,7 @@ class RecipientMergerImpl: RecipientMerger {
|
||||
for affectedRecipient in affectedRecipients {
|
||||
if affectedRecipient.isEmpty {
|
||||
// TODO: Should we clean up any more state related to the discarded recipient?
|
||||
aciSessionStore.mergeRecipient(affectedRecipient, into: mergedRecipient, tx: tx)
|
||||
sessionStore.mergeRecipientId(affectedRecipient.id, into: mergedRecipient.id, localIdentity: .aci, tx: tx)
|
||||
identityManager.mergeRecipient(affectedRecipient, into: mergedRecipient, tx: tx)
|
||||
blockedRecipientStore.mergeRecipientId(affectedRecipient.id, into: mergedRecipient.id, tx: tx)
|
||||
failIfThrows { try storyRecipientStore.mergeRecipient(affectedRecipient, into: mergedRecipient, tx: tx) }
|
||||
@ -860,7 +860,7 @@ class RecipientMergerImpl: RecipientMerger {
|
||||
intoValue: newRecipient.isEmpty ? mergedRecipient : newRecipient
|
||||
)
|
||||
|
||||
guard aciSessionStore.mightContainSession(for: recipientPair.fromValue, tx: tx) else {
|
||||
guard sessionStore.hasSessionRecords(forRecipientId: recipientPair.fromValue.id, localIdentity: .aci, tx: tx) else {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -894,7 +894,7 @@ class RecipientMergerImpl: RecipientMerger {
|
||||
// the session/identity for these recipients.
|
||||
if recipientPair.fromValue.uniqueId == recipientPair.intoValue.uniqueId {
|
||||
identityManager.removeRecipientIdentity(for: recipientPair.fromValue.uniqueId, tx: tx)
|
||||
aciSessionStore.deleteAllSessions(for: recipientPair.fromValue.uniqueId, tx: tx)
|
||||
sessionStore.deleteSessions(forRecipientId: recipientPair.fromValue.id, localIdentity: .aci, tx: tx)
|
||||
}
|
||||
|
||||
// The canonical case is adding an ACI to a recipient that already had a
|
||||
|
||||
@ -527,7 +527,6 @@ public class SentMessageTranscriptReceiverImpl: SentMessageTranscriptReceiver {
|
||||
}
|
||||
|
||||
private func archiveSessions(for address: SignalServiceAddress, tx: DBWriteTransaction) {
|
||||
let sessionStore = signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
sessionStore.archiveAllSessions(for: address, tx: tx)
|
||||
self.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore.archiveSessions(forAddress: address, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -276,12 +276,14 @@ extension AppSetup.GlobalsContinuation {
|
||||
)
|
||||
|
||||
let preKeyStore = PreKeyStore()
|
||||
let sessionStore = SessionStore()
|
||||
|
||||
let aciProtocolStore = SignalProtocolStore.build(
|
||||
dateProvider: dateProvider,
|
||||
identity: .aci,
|
||||
preKeyStore: preKeyStore,
|
||||
recipientIdFinder: recipientIdFinder,
|
||||
sessionStore: sessionStore,
|
||||
)
|
||||
let blockedRecipientStore = BlockedRecipientStore()
|
||||
let blockingManager = BlockingManager(
|
||||
@ -305,6 +307,7 @@ extension AppSetup.GlobalsContinuation {
|
||||
identity: .pni,
|
||||
preKeyStore: preKeyStore,
|
||||
recipientIdFinder: recipientIdFinder,
|
||||
sessionStore: sessionStore,
|
||||
)
|
||||
let profileManager = testDependencies.profileManager ?? OWSProfileManager(
|
||||
appReadiness: appReadiness,
|
||||
@ -320,6 +323,7 @@ extension AppSetup.GlobalsContinuation {
|
||||
aciProtocolStore: aciProtocolStore,
|
||||
pniProtocolStore: pniProtocolStore,
|
||||
preKeyStore: preKeyStore,
|
||||
sessionStore: sessionStore,
|
||||
)
|
||||
let signalService = testDependencies.signalService ?? OWSSignalService(libsignalNet: libsignalNet)
|
||||
let signalServiceAddressCache = SignalServiceAddressCache()
|
||||
@ -764,7 +768,6 @@ extension AppSetup.GlobalsContinuation {
|
||||
let badgeCountFetcher = BadgeCountFetcherImpl()
|
||||
|
||||
let identityManager = OWSIdentityManagerImpl(
|
||||
aciProtocolStore: aciProtocolStore,
|
||||
appReadiness: appReadiness,
|
||||
db: db,
|
||||
messageSenderJobQueue: messageSenderJobQueue,
|
||||
@ -775,6 +778,7 @@ extension AppSetup.GlobalsContinuation {
|
||||
recipientDatabaseTable: recipientDatabaseTable,
|
||||
recipientFetcher: recipientFetcher,
|
||||
recipientIdFinder: recipientIdFinder,
|
||||
sessionStore: sessionStore,
|
||||
storageServiceManager: storageServiceManager,
|
||||
tsAccountManager: tsAccountManager
|
||||
)
|
||||
@ -1036,7 +1040,6 @@ extension AppSetup.GlobalsContinuation {
|
||||
|
||||
let authorMergeHelper = AuthorMergeHelper()
|
||||
let recipientMerger = RecipientMergerImpl(
|
||||
aciSessionStore: aciProtocolStore.sessionStore,
|
||||
blockedRecipientStore: blockedRecipientStore,
|
||||
identityManager: identityManager,
|
||||
observers: RecipientMergerImpl.buildObservers(
|
||||
@ -1063,6 +1066,7 @@ extension AppSetup.GlobalsContinuation {
|
||||
recipientDatabaseTable: recipientDatabaseTable,
|
||||
recipientFetcher: recipientFetcher,
|
||||
searchableNameIndexer: searchableNameIndexer,
|
||||
sessionStore: sessionStore,
|
||||
storageServiceManager: storageServiceManager,
|
||||
storyRecipientStore: storyRecipientStore
|
||||
)
|
||||
|
||||
@ -71,6 +71,11 @@ public enum BuildFlags {
|
||||
// that's now dead because this is false.
|
||||
public static let decodeDeprecatedPreKeys = true
|
||||
|
||||
// Turn this off after all still-registered clients have run this
|
||||
// migration. That should happen by 2026-08-04. Then, delete all the code
|
||||
// that's now dead because this is false.
|
||||
public static let migrateDeprecatedSessions = true
|
||||
|
||||
public static let serviceIdBinaryProvisioning = true
|
||||
public static let serviceIdBinaryConstantOverhead = !serviceIdStrings || (build <= .internal)
|
||||
public static let serviceIdBinaryVariableOverhead = !serviceIdStrings || (build <= .dev)
|
||||
|
||||
@ -106,6 +106,6 @@ private class SessionResetJobRunner: JobRunner {
|
||||
|
||||
private func archiveAllSessions(for contactThread: TSContactThread, tx: DBWriteTransaction) {
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
sessionStore.archiveAllSessions(for: contactThread.contactAddress, tx: tx)
|
||||
sessionStore.archiveSessions(forAddress: contactThread.contactAddress, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1907,7 +1907,7 @@ public final class MessageReceiver {
|
||||
let sessionRecord = try sessionStore.loadSession(for: protocolAddress, context: tx)
|
||||
if try sessionRecord?.currentRatchetKeyMatches(ratchetKey) == true {
|
||||
Logger.info("Decryption error included ratchet key. Archiving...")
|
||||
sessionStore.archiveSession(for: sourceAci, deviceId: sourceDeviceId, tx: tx)
|
||||
sessionStore.archiveSession(forServiceId: sourceAci, deviceId: sourceDeviceId, tx: tx)
|
||||
didPerformSessionReset = true
|
||||
} else {
|
||||
didPerformSessionReset = false
|
||||
@ -2188,7 +2188,7 @@ public final class MessageReceiver {
|
||||
TSInfoMessage(thread: thread, messageType: .typeRemoteUserEndedSession).anyInsert(transaction: tx)
|
||||
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
sessionStore.archiveAllSessions(for: decryptedEnvelope.sourceAci, tx: tx)
|
||||
sessionStore.archiveSessions(forServiceId: decryptedEnvelope.sourceAci, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -48,10 +48,10 @@ public class MessageSender {
|
||||
|
||||
// MARK: - Creating Signal Protocol Sessions
|
||||
|
||||
private func validSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBReadTransaction) throws -> SessionRecord? {
|
||||
private func validSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBReadTransaction) throws -> LibSignalClient.SessionRecord? {
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
do {
|
||||
guard let session = try sessionStore.loadSession(for: serviceId, deviceId: deviceId, tx: tx) else {
|
||||
guard let session = try sessionStore.loadSession(forServiceId: serviceId, deviceId: deviceId, tx: tx) else {
|
||||
return nil
|
||||
}
|
||||
guard session.hasCurrentState else {
|
||||
@ -1555,7 +1555,7 @@ public class MessageSender {
|
||||
Logger.warn("Found identity key mismatch on outgoing message to \(serviceId).\(deviceId). Archiving session before retrying...")
|
||||
let signalProtocolStoreManager = DependenciesBridge.shared.signalProtocolStoreManager
|
||||
let aciSessionStore = signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
aciSessionStore.archiveSession(for: serviceId, deviceId: deviceId, tx: tx)
|
||||
aciSessionStore.archiveSession(forServiceId: serviceId, deviceId: deviceId, tx: tx)
|
||||
throw OWSRetryableMessageSenderError()
|
||||
} catch SignalError.untrustedIdentity {
|
||||
Logger.warn("Found untrusted identity on outgoing message to \(serviceId). Wrapping error and throwing...")
|
||||
@ -1755,7 +1755,7 @@ public class MessageSender {
|
||||
Logger.warn("Stale devices for \(serviceId): \(staleDevices)")
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
for staleDeviceId in staleDevices {
|
||||
sessionStore.archiveSession(for: serviceId, deviceId: staleDeviceId, tx: tx)
|
||||
sessionStore.archiveSession(forServiceId: serviceId, deviceId: staleDeviceId, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1809,7 +1809,7 @@ public class MessageSender {
|
||||
Logger.info("Archiving sessions for extra devices: \(devicesToRemove)")
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
for deviceId in devicesToRemove {
|
||||
sessionStore.archiveSession(for: serviceId, deviceId: deviceId, tx: tx)
|
||||
sessionStore.archiveSession(forServiceId: serviceId, deviceId: deviceId, tx: tx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -189,7 +189,6 @@ private extension OWSIdentity {
|
||||
}
|
||||
|
||||
public class OWSIdentityManagerImpl: OWSIdentityManager {
|
||||
private let aciProtocolStore: SignalProtocolStore
|
||||
private let appReadiness: AppReadiness
|
||||
private let db: any DB
|
||||
private let messageSenderJobQueue: MessageSenderJobQueue
|
||||
@ -202,12 +201,12 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
|
||||
private let recipientDatabaseTable: RecipientDatabaseTable
|
||||
private let recipientFetcher: RecipientFetcher
|
||||
private let recipientIdFinder: RecipientIdFinder
|
||||
private let sessionStore: SessionStore
|
||||
private let shareMyPhoneNumberStore: KeyValueStore
|
||||
private let storageServiceManager: StorageServiceManager
|
||||
private let tsAccountManager: TSAccountManager
|
||||
|
||||
public init(
|
||||
aciProtocolStore: SignalProtocolStore,
|
||||
init(
|
||||
appReadiness: AppReadiness,
|
||||
db: any DB,
|
||||
messageSenderJobQueue: MessageSenderJobQueue,
|
||||
@ -218,10 +217,10 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
|
||||
recipientDatabaseTable: RecipientDatabaseTable,
|
||||
recipientFetcher: RecipientFetcher,
|
||||
recipientIdFinder: RecipientIdFinder,
|
||||
sessionStore: SessionStore,
|
||||
storageServiceManager: StorageServiceManager,
|
||||
tsAccountManager: TSAccountManager
|
||||
) {
|
||||
self.aciProtocolStore = aciProtocolStore
|
||||
self.appReadiness = appReadiness
|
||||
self.db = db
|
||||
self.messageSenderJobQueue = messageSenderJobQueue
|
||||
@ -238,6 +237,7 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
|
||||
self.recipientDatabaseTable = recipientDatabaseTable
|
||||
self.recipientFetcher = recipientFetcher
|
||||
self.recipientIdFinder = recipientIdFinder
|
||||
self.sessionStore = sessionStore
|
||||
self.shareMyPhoneNumberStore = KeyValueStore(
|
||||
collection: "OWSIdentityManager.shareMyPhoneNumberStore"
|
||||
)
|
||||
@ -348,27 +348,27 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
|
||||
|
||||
@discardableResult
|
||||
public func saveIdentityKey(_ identityKey: Data, for serviceId: ServiceId, tx: DBWriteTransaction) -> Result<IdentityChange, RecipientIdError> {
|
||||
let recipientIdResult = recipientIdFinder.ensureRecipientUniqueId(for: serviceId, tx: tx)
|
||||
return recipientIdResult.map({ _saveIdentityKey(identityKey, for: serviceId, recipientUniqueId: $0, tx: tx) })
|
||||
let recipientResult = recipientIdFinder.ensureRecipient(for: serviceId, tx: tx)
|
||||
return recipientResult.map({ _saveIdentityKey(identityKey, for: serviceId, recipient: $0, tx: tx) })
|
||||
}
|
||||
|
||||
private func _saveIdentityKey(_ identityKey: Data, for serviceId: ServiceId, recipientUniqueId: RecipientUniqueId, tx: DBWriteTransaction) -> IdentityChange {
|
||||
private func _saveIdentityKey(_ identityKey: Data, for serviceId: ServiceId, recipient: SignalRecipient, tx: DBWriteTransaction) -> IdentityChange {
|
||||
owsAssertDebug(identityKey.count == Constants.storedIdentityKeyLength)
|
||||
|
||||
let existingIdentity = OWSRecipientIdentity.anyFetch(uniqueId: recipientUniqueId, transaction: tx)
|
||||
let existingIdentity = OWSRecipientIdentity.anyFetch(uniqueId: recipient.uniqueId, transaction: tx)
|
||||
guard let existingIdentity else {
|
||||
Logger.info("Saving first-use identity for \(serviceId)")
|
||||
OWSRecipientIdentity(
|
||||
uniqueId: recipientUniqueId,
|
||||
uniqueId: recipient.uniqueId,
|
||||
identityKey: identityKey,
|
||||
isFirstKnownKey: true,
|
||||
createdAt: Date(),
|
||||
verificationState: .default
|
||||
).anyInsert(transaction: tx)
|
||||
// Cancel any pending verification state sync messages for this recipient.
|
||||
clearSyncMessage(for: recipientUniqueId, tx: tx)
|
||||
clearSyncMessage(for: recipient.uniqueId, tx: tx)
|
||||
fireIdentityStateChangeNotification(after: tx)
|
||||
storageServiceManager.recordPendingUpdates(updatedRecipientUniqueIds: [recipientUniqueId])
|
||||
storageServiceManager.recordPendingUpdates(updatedRecipientUniqueIds: [recipient.uniqueId])
|
||||
return .newOrUnchanged
|
||||
}
|
||||
|
||||
@ -386,16 +386,16 @@ public class OWSIdentityManagerImpl: OWSIdentityManager {
|
||||
Logger.info("Saving new identity for \(serviceId): \(existingIdentity.verificationState) -> \(verificationState)")
|
||||
insertIdentityChangeInfoMessage(for: serviceId, wasIdentityVerified: existingIdentity.wasIdentityVerified, tx: tx)
|
||||
OWSRecipientIdentity(
|
||||
uniqueId: recipientUniqueId,
|
||||
uniqueId: recipient.uniqueId,
|
||||
identityKey: identityKey,
|
||||
isFirstKnownKey: false,
|
||||
createdAt: Date(),
|
||||
verificationState: verificationState.rawValue
|
||||
).anyUpsert(transaction: tx)
|
||||
aciProtocolStore.sessionStore.archiveAllSessions(for: serviceId, tx: tx)
|
||||
sessionStore.archiveSessions(forRecipientId: recipient.id, localIdentity: .aci, tx: tx)
|
||||
// Cancel any pending verification state sync messages for this recipient.
|
||||
clearSyncMessage(for: recipientUniqueId, tx: tx)
|
||||
storageServiceManager.recordPendingUpdates(updatedRecipientUniqueIds: [recipientUniqueId])
|
||||
clearSyncMessage(for: recipient.uniqueId, tx: tx)
|
||||
storageServiceManager.recordPendingUpdates(updatedRecipientUniqueIds: [recipient.uniqueId])
|
||||
return .replacedExisting
|
||||
}
|
||||
|
||||
|
||||
@ -364,8 +364,7 @@ public class OWSMessageDecrypter {
|
||||
|
||||
Logger.warn("Archiving session for undecryptable message from \(senderId)")
|
||||
let sessionStore = DependenciesBridge.shared.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore
|
||||
sessionStore.archiveSession(for: sourceAci, deviceId: sourceDeviceId, tx: transaction)
|
||||
|
||||
sessionStore.archiveSession(forServiceId: sourceAci, deviceId: sourceDeviceId, tx: transaction)
|
||||
trySendNullMessage(in: contactThread, senderId: senderId, transaction: transaction)
|
||||
return true
|
||||
} else {
|
||||
|
||||
@ -77,7 +77,7 @@ fileprivate extension SMKMessageType {
|
||||
|
||||
@objc
|
||||
public class SMKSecretSessionCipher: NSObject {
|
||||
private let currentSessionStore: SessionStore
|
||||
private let currentSessionStore: LibSignalClient.SessionStore
|
||||
private let currentPreKeyStore: LibSignalClient.PreKeyStore
|
||||
private let currentSignedPreKeyStore: SignedPreKeyStore
|
||||
private let currentKyberPreKeyStore: KyberPreKeyStore
|
||||
@ -86,7 +86,7 @@ public class SMKSecretSessionCipher: NSObject {
|
||||
|
||||
// public SecretSessionCipher(SignalProtocolStore signalProtocolStore) {
|
||||
init(
|
||||
sessionStore: SessionStore,
|
||||
sessionStore: LibSignalClient.SessionStore,
|
||||
preKeyStore: LibSignalClient.PreKeyStore,
|
||||
signedPreKeyStore: SignedPreKeyStore,
|
||||
kyberPreKeyStore: KyberPreKeyStore,
|
||||
|
||||
@ -9,9 +9,9 @@ import Foundation
|
||||
import LibSignalClient
|
||||
|
||||
extension SignalProtocolStore {
|
||||
static func mock(identity: OWSIdentity, preKeyStore: PreKeyStore) -> Self {
|
||||
static func mock(identity: OWSIdentity, preKeyStore: PreKeyStore, recipientIdFinder: RecipientIdFinder, sessionStore: SessionStore) -> Self {
|
||||
return SignalProtocolStore(
|
||||
sessionStore: MockSessionStore(),
|
||||
sessionStore: SessionManagerForIdentity(identity: identity, recipientIdFinder: recipientIdFinder, sessionStore: sessionStore),
|
||||
preKeyStore: PreKeyStoreImpl(for: identity, preKeyStore: preKeyStore),
|
||||
signedPreKeyStore: SignedPreKeyStoreImpl(for: identity, preKeyStore: preKeyStore),
|
||||
kyberPreKeyStore: KyberPreKeyStoreImpl(for: identity, dateProvider: Date.provider, preKeyStore: preKeyStore),
|
||||
@ -19,20 +19,4 @@ extension SignalProtocolStore {
|
||||
}
|
||||
}
|
||||
|
||||
class MockSessionStore: SignalSessionStore {
|
||||
func mightContainSession(for recipient: SignalRecipient, tx: DBReadTransaction) -> Bool { false }
|
||||
func mergeRecipient(_ recipient: SignalRecipient, into targetRecipient: SignalRecipient, tx: DBWriteTransaction) { }
|
||||
func archiveAllSessions(for serviceId: ServiceId, tx: DBWriteTransaction) { }
|
||||
func archiveAllSessions(for address: SignalServiceAddress, tx: DBWriteTransaction) { }
|
||||
func archiveSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBWriteTransaction) { }
|
||||
func loadSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBReadTransaction) throws -> LibSignalClient.SessionRecord? { nil }
|
||||
func loadSession(for address: ProtocolAddress, context: StoreContext) throws -> LibSignalClient.SessionRecord? { nil }
|
||||
func resetSessionStore(tx: DBWriteTransaction) { }
|
||||
func deleteAllSessions(for serviceId: ServiceId, tx: DBWriteTransaction) { }
|
||||
func deleteAllSessions(for recipientUniqueId: RecipientUniqueId, tx: DBWriteTransaction) { }
|
||||
func removeAll(tx: DBWriteTransaction) { }
|
||||
func loadExistingSessions(for addresses: [ProtocolAddress], context: StoreContext) throws -> [LibSignalClient.SessionRecord] { [] }
|
||||
func storeSession(_ record: LibSignalClient.SessionRecord, for address: ProtocolAddress, context: StoreContext) throws { }
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -1,311 +0,0 @@
|
||||
//
|
||||
// Copyright 2020 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
public import LibSignalClient
|
||||
|
||||
public final class SSKSessionStore: SignalSessionStore {
|
||||
// Note that even though the values here are always serialized Data,
|
||||
// using AnyObject here defers any checking or conversion of the values
|
||||
// when converting from an NSDictionary.
|
||||
fileprivate typealias SessionsByDeviceDictionary = [Int32: AnyObject]
|
||||
|
||||
private let keyValueStore: KeyValueStore
|
||||
private let recipientIdFinder: RecipientIdFinder
|
||||
|
||||
public init(
|
||||
for identity: OWSIdentity,
|
||||
recipientIdFinder: RecipientIdFinder
|
||||
) {
|
||||
self.keyValueStore = KeyValueStore(collection: {
|
||||
switch identity {
|
||||
case .aci:
|
||||
return "TSStorageManagerSessionStoreCollection"
|
||||
case .pni:
|
||||
return "TSStorageManagerPNISessionStoreCollection"
|
||||
}
|
||||
}())
|
||||
self.recipientIdFinder = recipientIdFinder
|
||||
}
|
||||
|
||||
fileprivate func loadSerializedSession(
|
||||
for serviceId: ServiceId,
|
||||
deviceId: UInt32,
|
||||
tx: DBReadTransaction
|
||||
) throws -> Data? {
|
||||
switch recipientIdFinder.recipientUniqueId(for: serviceId, tx: tx) {
|
||||
case .none:
|
||||
return nil
|
||||
case .some(.success(let recipientUniqueId)):
|
||||
return loadSerializedSession(for: recipientUniqueId, deviceId: deviceId, tx: tx)
|
||||
case .some(.failure(let error)):
|
||||
switch error {
|
||||
case .mustNotUsePniBecauseAciExists:
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func serializedSession(fromDatabaseRepresentation entry: Any) -> Data? {
|
||||
switch entry {
|
||||
case let data as Data:
|
||||
return data
|
||||
default:
|
||||
owsFailDebug("unexpected entry in session store: \(type(of: entry))")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func loadSerializedSession(
|
||||
for recipientUniqueId: String,
|
||||
deviceId: UInt32,
|
||||
tx: DBReadTransaction
|
||||
) -> Data? {
|
||||
owsAssertDebug(!recipientUniqueId.isEmpty)
|
||||
owsAssertDebug(deviceId > 0)
|
||||
|
||||
let dictionary = loadAllSerializedSessions(for: recipientUniqueId, tx: tx)
|
||||
guard let entry = dictionary?[Int32(bitPattern: deviceId)] else {
|
||||
return nil
|
||||
}
|
||||
return serializedSession(fromDatabaseRepresentation: entry)
|
||||
}
|
||||
|
||||
private func loadAllSerializedSessions(
|
||||
for recipientUniqueId: String,
|
||||
tx: DBReadTransaction
|
||||
) -> SessionsByDeviceDictionary? {
|
||||
owsAssertDebug(!recipientUniqueId.isEmpty)
|
||||
|
||||
guard let serialized = keyValueStore.getData(recipientUniqueId, transaction: tx) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
let rawDictionary: NSDictionary?
|
||||
do {
|
||||
rawDictionary = try NSKeyedUnarchiver.unarchivedObject(ofClasses: [NSDictionary.self, NSNumber.self, NSData.self], from: serialized) as? NSDictionary
|
||||
} catch let error as NSError {
|
||||
// Deliberately don't log the full error; it might contain session data.
|
||||
Logger.error("Unknown data (or legacy session) in session store; continuing as if there were no stored sessions (\(error.domain) \(error.code))")
|
||||
return nil
|
||||
}
|
||||
|
||||
guard let dictionary = rawDictionary as? SessionsByDeviceDictionary else {
|
||||
Logger.error("Invalid device ID keys in session store; continuing as if there were no stored sessions")
|
||||
return nil
|
||||
}
|
||||
|
||||
return dictionary
|
||||
}
|
||||
|
||||
fileprivate func storeSerializedSession(
|
||||
_ sessionData: Data,
|
||||
for serviceId: ServiceId,
|
||||
deviceId: UInt32,
|
||||
tx: DBWriteTransaction
|
||||
) throws {
|
||||
switch recipientIdFinder.ensureRecipientUniqueId(for: serviceId, tx: tx) {
|
||||
case .failure(let error):
|
||||
switch error {
|
||||
case .mustNotUsePniBecauseAciExists:
|
||||
throw error
|
||||
}
|
||||
case .success(let recipientUniqueId):
|
||||
storeSerializedSession(for: recipientUniqueId, deviceId: deviceId, sessionData: sessionData, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
private func storeSerializedSession(
|
||||
for recipientUniqueId: String,
|
||||
deviceId: UInt32,
|
||||
sessionData: Data,
|
||||
tx: DBWriteTransaction
|
||||
) {
|
||||
owsAssertDebug(!recipientUniqueId.isEmpty)
|
||||
owsAssertDebug(deviceId > 0)
|
||||
|
||||
var dictionary = loadAllSerializedSessions(for: recipientUniqueId, tx: tx) ?? [:]
|
||||
dictionary[Int32(bitPattern: deviceId)] = sessionData as NSData
|
||||
saveSerializedSessions(dictionary, for: recipientUniqueId, tx: tx)
|
||||
}
|
||||
|
||||
private func saveSerializedSessions(
|
||||
_ sessions: SessionsByDeviceDictionary,
|
||||
for recipientUniqueId: String,
|
||||
tx: DBWriteTransaction
|
||||
) {
|
||||
// Avoid using KeyValueStore.setObject(_:key:transaction:).
|
||||
// The database-based KV store implicitly archives using NSKeyedArchiver,
|
||||
// but the in-memory one for testing does not.
|
||||
// In order for loadAllSerializedSessions(for:tx:) to manually control deserialization,
|
||||
// we need to consistently archive.
|
||||
// This will also make it easier to potentially move away from NSKeyedArchiver in the future.
|
||||
do {
|
||||
let archived = try NSKeyedArchiver.archivedData(withRootObject: sessions, requiringSecureCoding: true)
|
||||
keyValueStore.setData(archived, key: recipientUniqueId, transaction: tx)
|
||||
} catch {
|
||||
Logger.debug("failed to serialize session data: \(error)\n\(sessions)")
|
||||
owsFailDebug("failed to serialize session data")
|
||||
// At least clear out whatever's in the store, so we don't keep old sessions around longer than we should.
|
||||
keyValueStore.setData(nil, key: recipientUniqueId, transaction: tx)
|
||||
}
|
||||
}
|
||||
|
||||
public func mightContainSession(for recipient: SignalRecipient, tx: DBReadTransaction) -> Bool {
|
||||
return keyValueStore.hasValue(recipient.uniqueId, transaction: tx)
|
||||
}
|
||||
|
||||
public func mergeRecipient(_ recipient: SignalRecipient, into targetRecipient: SignalRecipient, tx: DBWriteTransaction) {
|
||||
let recipientPair = MergePair(fromValue: recipient, intoValue: targetRecipient)
|
||||
let sessionBlob = recipientPair.map { keyValueStore.getData($0.uniqueId, transaction: tx) }
|
||||
guard let fromValue = sessionBlob.fromValue else {
|
||||
return
|
||||
}
|
||||
if sessionBlob.intoValue == nil {
|
||||
keyValueStore.setData(fromValue, key: targetRecipient.uniqueId, transaction: tx)
|
||||
}
|
||||
keyValueStore.removeValue(forKey: recipient.uniqueId, transaction: tx)
|
||||
}
|
||||
|
||||
public func deleteAllSessions(for serviceId: ServiceId, tx: DBWriteTransaction) {
|
||||
Logger.info("deleting all sessions for \(serviceId)")
|
||||
switch recipientIdFinder.recipientUniqueId(for: serviceId, tx: tx) {
|
||||
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
|
||||
// There can't possibly be any sessions that need to be deleted.
|
||||
return
|
||||
case .some(.success(let recipientUniqueId)):
|
||||
owsAssertDebug(!recipientUniqueId.isEmpty)
|
||||
deleteAllSessions(for: recipientUniqueId, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
public func deleteAllSessions(for recipientUniqueId: RecipientUniqueId, tx: DBWriteTransaction) {
|
||||
keyValueStore.removeValue(forKey: recipientUniqueId, transaction: tx)
|
||||
}
|
||||
|
||||
public func archiveAllSessions(for serviceId: ServiceId, tx: DBWriteTransaction) {
|
||||
Logger.info("archiving all sessions for \(serviceId)")
|
||||
switch recipientIdFinder.recipientUniqueId(for: serviceId, tx: tx) {
|
||||
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
|
||||
// There can't possibly be any sessions that need to be archived.
|
||||
return
|
||||
case .some(.success(let recipientUniqueId)):
|
||||
archiveAllSessions(for: recipientUniqueId, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
public func archiveAllSessions(for address: SignalServiceAddress, tx: DBWriteTransaction) {
|
||||
Logger.info("archiving all sessions for \(address)")
|
||||
switch recipientIdFinder.recipientUniqueId(for: address, tx: tx) {
|
||||
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
|
||||
// There can't possibly be any sessions that need to be archived.
|
||||
return
|
||||
case .some(.success(let recipientUniqueId)):
|
||||
archiveAllSessions(for: recipientUniqueId, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
private func archiveAllSessions(for recipientUniqueId: RecipientUniqueId, tx: DBWriteTransaction) {
|
||||
owsAssertDebug(!recipientUniqueId.isEmpty)
|
||||
|
||||
guard let dictionary = loadAllSerializedSessions(for: recipientUniqueId, tx: tx) else {
|
||||
// We never had a session for this account in the first place.
|
||||
return
|
||||
}
|
||||
|
||||
let newDictionary: SessionsByDeviceDictionary = dictionary.mapValues { record in
|
||||
guard let data = serializedSession(fromDatabaseRepresentation: record) else {
|
||||
// We've already logged an error; skip this session.
|
||||
return record
|
||||
}
|
||||
|
||||
do {
|
||||
let session = try SessionRecord(bytes: data)
|
||||
session.archiveCurrentState()
|
||||
return session.serialize() as NSData
|
||||
} catch {
|
||||
owsFailDebug("\(error)")
|
||||
return record
|
||||
}
|
||||
}
|
||||
|
||||
saveSerializedSessions(newDictionary, for: recipientUniqueId, tx: tx)
|
||||
}
|
||||
|
||||
public func resetSessionStore(tx: DBWriteTransaction) {
|
||||
Logger.warn("resetting session store")
|
||||
keyValueStore.removeAll(transaction: tx)
|
||||
}
|
||||
|
||||
public func removeAll(tx: DBWriteTransaction) {
|
||||
keyValueStore.removeAll(transaction: tx)
|
||||
}
|
||||
}
|
||||
|
||||
extension SSKSessionStore {
|
||||
public func loadSession(
|
||||
for serviceId: ServiceId,
|
||||
deviceId: DeviceId,
|
||||
tx: DBReadTransaction
|
||||
) throws -> SessionRecord? {
|
||||
guard let serializedData = try loadSerializedSession(for: serviceId, deviceId: deviceId.uint32Value, tx: tx) else {
|
||||
return nil
|
||||
}
|
||||
return try SessionRecord(bytes: serializedData)
|
||||
}
|
||||
|
||||
fileprivate func storeSession(
|
||||
_ record: SessionRecord,
|
||||
for serviceId: ServiceId,
|
||||
deviceId: UInt32,
|
||||
tx: DBWriteTransaction
|
||||
) throws {
|
||||
try storeSerializedSession(record.serialize(), for: serviceId, deviceId: deviceId, tx: tx)
|
||||
}
|
||||
|
||||
public func archiveSession(for serviceId: ServiceId, deviceId: DeviceId, tx: DBWriteTransaction) {
|
||||
do {
|
||||
guard let session = try loadSession(for: serviceId, deviceId: deviceId, tx: tx) else {
|
||||
return
|
||||
}
|
||||
session.archiveCurrentState()
|
||||
try storeSession(session, for: serviceId, deviceId: deviceId.uint32Value, tx: tx)
|
||||
} catch {
|
||||
owsFailDebug("\(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extension SSKSessionStore: LibSignalClient.SessionStore {
|
||||
public func loadSession(for address: ProtocolAddress, context: StoreContext) throws -> SessionRecord? {
|
||||
return try loadSession(for: address.serviceId, deviceId: address.deviceIdObj, tx: context.asTransaction)
|
||||
}
|
||||
|
||||
public func loadExistingSessions(
|
||||
for addresses: [ProtocolAddress],
|
||||
context: StoreContext
|
||||
) throws -> [SessionRecord] {
|
||||
return try addresses.map { address in
|
||||
guard let session = try loadSession(for: address, context: context) else {
|
||||
throw SignalError.sessionNotFound("\(address)")
|
||||
}
|
||||
return session
|
||||
}
|
||||
}
|
||||
|
||||
public func storeSession(_ record: SessionRecord, for address: ProtocolAddress, context: StoreContext) throws {
|
||||
try storeSession(record, for: address.serviceId, deviceId: address.deviceId, tx: context.asTransaction)
|
||||
}
|
||||
}
|
||||
|
||||
#if TESTABLE_BUILD
|
||||
|
||||
extension SSKSessionStore {
|
||||
// Available through `@testable import`
|
||||
internal var keyValueStoreForTesting: KeyValueStore {
|
||||
self.keyValueStore
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -82,7 +82,7 @@ public class SenderKeyStore {
|
||||
// comparing a set of (deviceId, registrationId) structs, we should be able
|
||||
// to detect reused deviceIds that will need an SKDM.
|
||||
let registrationId = try sessionStore.loadSession(
|
||||
for: serviceId,
|
||||
forServiceId: serviceId,
|
||||
deviceId: deviceId,
|
||||
tx: tx
|
||||
)?.remoteRegistrationId()
|
||||
|
||||
337
SignalServiceKit/Storage/AxolotlStore/SessionStore.swift
Normal file
337
SignalServiceKit/Storage/AxolotlStore/SessionStore.swift
Normal file
@ -0,0 +1,337 @@
|
||||
//
|
||||
// Copyright 2025 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import GRDB
|
||||
public import LibSignalClient
|
||||
|
||||
struct SessionRecord: Codable, FetchableRecord, PersistableRecord {
|
||||
static let databaseTableName: String = "Session"
|
||||
|
||||
let id: Int64
|
||||
var recipientId: SignalRecipient.RowId
|
||||
let localIdentity: OWSIdentity
|
||||
let deviceId: DeviceId
|
||||
/// May be nil if there was a legacy session.
|
||||
var serializedRecord: Data?
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case id
|
||||
case recipientId
|
||||
case localIdentity
|
||||
case deviceId
|
||||
case serializedRecord
|
||||
}
|
||||
|
||||
enum Columns {
|
||||
static let recipientId = Column(CodingKeys.recipientId.rawValue)
|
||||
static let localIdentity = Column(CodingKeys.localIdentity.rawValue)
|
||||
static let deviceId = Column(CodingKeys.deviceId.rawValue)
|
||||
static let serializedRecord = Column(CodingKeys.serializedRecord.rawValue)
|
||||
}
|
||||
}
|
||||
|
||||
struct SessionStore {
|
||||
func hasSessionRecords(
|
||||
forRecipientId recipientId: SignalRecipient.RowId,
|
||||
localIdentity: OWSIdentity,
|
||||
tx: DBReadTransaction,
|
||||
) -> Bool {
|
||||
let sessionRecords = fetchSessionRecords(
|
||||
forRecipientId: recipientId,
|
||||
localIdentity: localIdentity,
|
||||
tx: tx,
|
||||
)
|
||||
return !sessionRecords.isEmpty
|
||||
}
|
||||
|
||||
func mergeRecipientId(
|
||||
_ recipientId: SignalRecipient.RowId,
|
||||
into targetRecipientId: SignalRecipient.RowId,
|
||||
localIdentity: OWSIdentity,
|
||||
tx: DBWriteTransaction,
|
||||
) {
|
||||
if hasSessionRecords(forRecipientId: targetRecipientId, localIdentity: localIdentity, tx: tx) {
|
||||
// There's already sessions -- prefers those instead of ours.
|
||||
deleteSessions(forRecipientId: recipientId, localIdentity: localIdentity, tx: tx)
|
||||
} else {
|
||||
// There's no sessions -- move ours and reuse them.
|
||||
let sessionRecords = fetchSessionRecords(
|
||||
forRecipientId: recipientId,
|
||||
localIdentity: localIdentity,
|
||||
tx: tx,
|
||||
)
|
||||
for var sessionRecord in sessionRecords {
|
||||
sessionRecord.recipientId = targetRecipientId
|
||||
failIfThrows { try sessionRecord.update(tx.database) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func buildQuery(
|
||||
recipientId: SignalRecipient.RowId,
|
||||
localIdentity: OWSIdentity,
|
||||
deviceId: DeviceId? = nil,
|
||||
) -> QueryInterfaceRequest<SessionRecord> {
|
||||
var result = SessionRecord.filter(SessionRecord.Columns.recipientId == recipientId)
|
||||
result = result.filter(SessionRecord.Columns.localIdentity == localIdentity.rawValue)
|
||||
if let deviceId {
|
||||
result = result.filter(SessionRecord.Columns.deviceId == deviceId.rawValue)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
private func fetchSessionRecords(
|
||||
forRecipientId recipientId: SignalRecipient.RowId,
|
||||
localIdentity: OWSIdentity,
|
||||
deviceId: DeviceId? = nil,
|
||||
tx: DBReadTransaction,
|
||||
) -> [SessionRecord] {
|
||||
return failIfThrows {
|
||||
return try buildQuery(
|
||||
recipientId: recipientId,
|
||||
localIdentity: localIdentity,
|
||||
deviceId: deviceId,
|
||||
).fetchAll(tx.database)
|
||||
}
|
||||
}
|
||||
|
||||
func fetchSession(
|
||||
forRecipientId recipientId: SignalRecipient.RowId,
|
||||
localIdentity: OWSIdentity,
|
||||
deviceId: DeviceId,
|
||||
tx: DBReadTransaction,
|
||||
) throws -> LibSignalClient.SessionRecord? {
|
||||
return try (fetchSessionRecords(
|
||||
forRecipientId: recipientId,
|
||||
localIdentity: localIdentity,
|
||||
deviceId: deviceId,
|
||||
tx: tx,
|
||||
).first?.serializedRecord).map(LibSignalClient.SessionRecord.init(bytes:))
|
||||
}
|
||||
|
||||
func archiveSessions(
|
||||
forRecipientId recipientId: SignalRecipient.RowId,
|
||||
localIdentity: OWSIdentity,
|
||||
tx: DBWriteTransaction,
|
||||
) {
|
||||
_archiveSessions(forRecipientId: recipientId, localIdentity: localIdentity, deviceId: nil, tx: tx)
|
||||
}
|
||||
|
||||
fileprivate func _archiveSessions(
|
||||
forRecipientId recipientId: SignalRecipient.RowId,
|
||||
localIdentity: OWSIdentity,
|
||||
deviceId: DeviceId?,
|
||||
tx: DBWriteTransaction,
|
||||
) {
|
||||
let sessionRecords = fetchSessionRecords(
|
||||
forRecipientId: recipientId,
|
||||
localIdentity: localIdentity,
|
||||
deviceId: deviceId,
|
||||
tx: tx,
|
||||
)
|
||||
for var sessionRecord in sessionRecords {
|
||||
guard let serializedRecord = sessionRecord.serializedRecord else {
|
||||
Logger.warn("couldn't decode legacy session to archive it; leaving it as-is")
|
||||
continue
|
||||
}
|
||||
let libSignalSessionRecord: LibSignalClient.SessionRecord
|
||||
do {
|
||||
libSignalSessionRecord = try LibSignalClient.SessionRecord(bytes: serializedRecord)
|
||||
} catch {
|
||||
owsFailDebug("couldn't decode session to archive it: \(error)")
|
||||
continue
|
||||
}
|
||||
libSignalSessionRecord.archiveCurrentState()
|
||||
sessionRecord.serializedRecord = libSignalSessionRecord.serialize()
|
||||
failIfThrows { try sessionRecord.update(tx.database) }
|
||||
}
|
||||
}
|
||||
|
||||
func deleteSessions(
|
||||
forRecipientId recipientId: SignalRecipient.RowId,
|
||||
localIdentity: OWSIdentity,
|
||||
tx: DBWriteTransaction,
|
||||
) {
|
||||
failIfThrows {
|
||||
_ = try buildQuery(
|
||||
recipientId: recipientId,
|
||||
localIdentity: localIdentity,
|
||||
deviceId: nil,
|
||||
).deleteAll(tx.database)
|
||||
}
|
||||
}
|
||||
|
||||
func upsertSession(
|
||||
forRecipientId recipientId: SignalRecipient.RowId,
|
||||
deviceId: DeviceId,
|
||||
localIdentity: OWSIdentity,
|
||||
recordData: Data,
|
||||
tx: DBWriteTransaction,
|
||||
) {
|
||||
failIfThrows {
|
||||
try tx.database.execute(
|
||||
sql: """
|
||||
INSERT OR REPLACE INTO \(SessionRecord.databaseTableName) (
|
||||
\(SessionRecord.Columns.recipientId.name),
|
||||
\(SessionRecord.Columns.deviceId.name),
|
||||
\(SessionRecord.Columns.localIdentity.name),
|
||||
\(SessionRecord.Columns.serializedRecord.name)
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
arguments: [
|
||||
recipientId,
|
||||
deviceId.rawValue,
|
||||
localIdentity.rawValue,
|
||||
recordData,
|
||||
],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteAllSessions(tx: DBWriteTransaction) {
|
||||
failIfThrows { _ = try SessionRecord.deleteAll(tx.database) }
|
||||
}
|
||||
}
|
||||
|
||||
public class SessionManagerForIdentity: LibSignalClient.SessionStore {
|
||||
private let identity: OWSIdentity
|
||||
private let recipientIdFinder: RecipientIdFinder
|
||||
private let sessionStore: SessionStore
|
||||
|
||||
init(
|
||||
identity: OWSIdentity,
|
||||
recipientIdFinder: RecipientIdFinder,
|
||||
sessionStore: SessionStore,
|
||||
) {
|
||||
self.identity = identity
|
||||
self.recipientIdFinder = recipientIdFinder
|
||||
self.sessionStore = sessionStore
|
||||
}
|
||||
|
||||
func archiveSession(forServiceId serviceId: ServiceId, deviceId: DeviceId, tx: DBWriteTransaction) {
|
||||
Logger.info("archiving session for \(serviceId).\(deviceId)")
|
||||
self._archiveSessions(
|
||||
recipientIdResult: self.recipientIdFinder.recipientId(for: serviceId, tx: tx),
|
||||
deviceId: deviceId,
|
||||
tx: tx,
|
||||
)
|
||||
}
|
||||
|
||||
public func archiveSessions(forServiceId serviceId: ServiceId, tx: DBWriteTransaction) {
|
||||
Logger.info("archiving all sessions for \(serviceId)")
|
||||
self._archiveSessions(
|
||||
recipientIdResult: self.recipientIdFinder.recipientId(for: serviceId, tx: tx),
|
||||
deviceId: nil,
|
||||
tx: tx,
|
||||
)
|
||||
}
|
||||
|
||||
func archiveSessions(forAddress address: SignalServiceAddress, tx: DBWriteTransaction) {
|
||||
Logger.info("archiving all sessions for \(address)")
|
||||
self._archiveSessions(
|
||||
recipientIdResult: self.recipientIdFinder.recipientId(for: address, tx: tx),
|
||||
deviceId: nil,
|
||||
tx: tx,
|
||||
)
|
||||
}
|
||||
|
||||
private func _archiveSessions(
|
||||
recipientIdResult: Result<SignalRecipient.RowId, RecipientIdError>?,
|
||||
deviceId: DeviceId?,
|
||||
tx: DBWriteTransaction,
|
||||
) {
|
||||
switch recipientIdResult {
|
||||
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
|
||||
// There can't possibly be any sessions that need to be archived.
|
||||
return
|
||||
case .some(.success(let recipientId)):
|
||||
self.sessionStore._archiveSessions(
|
||||
forRecipientId: recipientId,
|
||||
localIdentity: self.identity,
|
||||
deviceId: deviceId,
|
||||
tx: tx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
public func deleteSessions(forServiceId serviceId: ServiceId, tx: DBWriteTransaction) {
|
||||
switch self.recipientIdFinder.recipientId(for: serviceId, tx: tx) {
|
||||
case .none, .some(.failure(.mustNotUsePniBecauseAciExists)):
|
||||
// There can't possibly be any sessions that need to be deleted.
|
||||
return
|
||||
case .some(.success(let recipientId)):
|
||||
self.sessionStore.deleteSessions(forRecipientId: recipientId, localIdentity: self.identity, tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
func loadSession(
|
||||
forServiceId serviceId: ServiceId,
|
||||
deviceId: DeviceId,
|
||||
tx: DBReadTransaction,
|
||||
) throws -> LibSignalClient.SessionRecord? {
|
||||
switch self.recipientIdFinder.recipientId(for: serviceId, tx: tx) {
|
||||
case .none:
|
||||
return nil
|
||||
case .some(.success(let recipientId)):
|
||||
return try self.sessionStore.fetchSession(
|
||||
forRecipientId: recipientId,
|
||||
localIdentity: self.identity,
|
||||
deviceId: deviceId,
|
||||
tx: tx,
|
||||
)
|
||||
case .some(.failure(let error)):
|
||||
switch error {
|
||||
case .mustNotUsePniBecauseAciExists:
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func loadSession(
|
||||
for address: LibSignalClient.ProtocolAddress,
|
||||
context: any LibSignalClient.StoreContext,
|
||||
) throws -> LibSignalClient.SessionRecord? {
|
||||
return try loadSession(
|
||||
forServiceId: address.serviceId,
|
||||
deviceId: address.deviceIdObj,
|
||||
tx: context.asTransaction,
|
||||
)
|
||||
}
|
||||
|
||||
public func loadExistingSessions(
|
||||
for addresses: [LibSignalClient.ProtocolAddress],
|
||||
context: any LibSignalClient.StoreContext,
|
||||
) throws -> [LibSignalClient.SessionRecord] {
|
||||
return try addresses.map { address in
|
||||
guard let session = try loadSession(for: address, context: context) else {
|
||||
throw SignalError.sessionNotFound("\(address)")
|
||||
}
|
||||
return session
|
||||
}
|
||||
}
|
||||
|
||||
public func storeSession(
|
||||
_ record: LibSignalClient.SessionRecord,
|
||||
for address: LibSignalClient.ProtocolAddress,
|
||||
context: any LibSignalClient.StoreContext,
|
||||
) throws {
|
||||
switch recipientIdFinder.ensureRecipientId(for: address.serviceId, tx: context.asTransaction) {
|
||||
case .success(let recipientId):
|
||||
self.sessionStore.upsertSession(
|
||||
forRecipientId: recipientId,
|
||||
deviceId: address.deviceIdObj,
|
||||
localIdentity: self.identity,
|
||||
recordData: record.serialize(),
|
||||
tx: context.asTransaction,
|
||||
)
|
||||
case .failure(let error):
|
||||
switch error {
|
||||
case .mustNotUsePniBecauseAciExists:
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -6,7 +6,7 @@
|
||||
/// Wraps the stores for 1:1 sessions that use the Signal Protocol (Double Ratchet + X3DH).
|
||||
|
||||
public struct SignalProtocolStore {
|
||||
public let sessionStore: SignalSessionStore
|
||||
public let sessionStore: SessionManagerForIdentity
|
||||
public let preKeyStore: PreKeyStoreImpl
|
||||
public let signedPreKeyStore: SignedPreKeyStoreImpl
|
||||
public let kyberPreKeyStore: KyberPreKeyStoreImpl
|
||||
@ -16,9 +16,10 @@ public struct SignalProtocolStore {
|
||||
identity: OWSIdentity,
|
||||
preKeyStore: PreKeyStore,
|
||||
recipientIdFinder: RecipientIdFinder,
|
||||
sessionStore: SessionStore,
|
||||
) -> Self {
|
||||
return Self(
|
||||
sessionStore: SSKSessionStore(for: identity, recipientIdFinder: recipientIdFinder),
|
||||
sessionStore: SessionManagerForIdentity(identity: identity, recipientIdFinder: recipientIdFinder, sessionStore: sessionStore),
|
||||
preKeyStore: PreKeyStoreImpl(for: identity, preKeyStore: preKeyStore),
|
||||
signedPreKeyStore: SignedPreKeyStoreImpl(for: identity, preKeyStore: preKeyStore),
|
||||
kyberPreKeyStore: KyberPreKeyStoreImpl(for: identity, dateProvider: dateProvider, preKeyStore: preKeyStore),
|
||||
@ -31,6 +32,7 @@ public struct SignalProtocolStoreManager {
|
||||
let aciProtocolStore: SignalProtocolStore
|
||||
let pniProtocolStore: SignalProtocolStore
|
||||
let preKeyStore: PreKeyStore
|
||||
let sessionStore: SessionStore
|
||||
|
||||
public func signalProtocolStore(for identity: OWSIdentity) -> SignalProtocolStore {
|
||||
switch identity {
|
||||
@ -43,11 +45,11 @@ public struct SignalProtocolStoreManager {
|
||||
|
||||
public func removeAllKeys(tx: DBWriteTransaction) {
|
||||
for signalProtocolStore in [aciProtocolStore, pniProtocolStore] {
|
||||
signalProtocolStore.sessionStore.removeAll(tx: tx)
|
||||
signalProtocolStore.preKeyStore.removeMetadata(tx: tx)
|
||||
signalProtocolStore.signedPreKeyStore.removeMetadata(tx: tx)
|
||||
signalProtocolStore.kyberPreKeyStore.removeMetadata(tx: tx)
|
||||
}
|
||||
self.sessionStore.deleteAllSessions(tx: tx)
|
||||
self.preKeyStore.removeAll(tx: tx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
//
|
||||
// Copyright 2023 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
import Foundation
|
||||
public import LibSignalClient
|
||||
|
||||
public protocol SignalSessionStore: LibSignalClient.SessionStore {
|
||||
func mightContainSession(
|
||||
for recipient: SignalRecipient,
|
||||
tx: DBReadTransaction
|
||||
) -> Bool
|
||||
|
||||
func mergeRecipient(
|
||||
_ recipient: SignalRecipient,
|
||||
into targetRecipient: SignalRecipient,
|
||||
tx: DBWriteTransaction
|
||||
)
|
||||
|
||||
func archiveAllSessions(
|
||||
for serviceId: ServiceId,
|
||||
tx: DBWriteTransaction
|
||||
)
|
||||
|
||||
/// Deprecated. Prefer the variant that accepts a ServiceId.
|
||||
func archiveAllSessions(
|
||||
for address: SignalServiceAddress,
|
||||
tx: DBWriteTransaction
|
||||
)
|
||||
|
||||
func archiveSession(
|
||||
for serviceId: ServiceId,
|
||||
deviceId: DeviceId,
|
||||
tx: DBWriteTransaction
|
||||
)
|
||||
|
||||
func loadSession(
|
||||
for serviceId: ServiceId,
|
||||
deviceId: DeviceId,
|
||||
tx: DBReadTransaction
|
||||
) throws -> SessionRecord?
|
||||
|
||||
func loadSession(
|
||||
for address: ProtocolAddress,
|
||||
context: StoreContext
|
||||
) throws -> SessionRecord?
|
||||
|
||||
func resetSessionStore(tx: DBWriteTransaction)
|
||||
|
||||
func deleteAllSessions(
|
||||
for serviceId: ServiceId,
|
||||
tx: DBWriteTransaction
|
||||
)
|
||||
|
||||
func deleteAllSessions(
|
||||
for recipientUniqueId: RecipientUniqueId,
|
||||
tx: DBWriteTransaction
|
||||
)
|
||||
|
||||
func removeAll(tx: DBWriteTransaction)
|
||||
}
|
||||
@ -400,6 +400,7 @@ public enum DatabaseRecovery {
|
||||
StoryRecipient.databaseTableName,
|
||||
PreKey.databaseTableName,
|
||||
KyberPreKeyUseRecord.databaseTableName,
|
||||
SignalServiceKit.SessionRecord.databaseTableName,
|
||||
]
|
||||
|
||||
/// Copy tables that must be copied flawlessly. Operation throws if any tables fail.
|
||||
|
||||
@ -315,6 +315,7 @@ public class GRDBSchemaMigrator {
|
||||
case addPinnedAtTimestampToPinnedMessageTable
|
||||
case dropTSAttachment
|
||||
case deprecateStoredShouldStartExpireTimer
|
||||
case addSession
|
||||
|
||||
// NOTE: Every time we add a migration id, consider
|
||||
// incrementing grdbSchemaVersionLatest.
|
||||
@ -437,7 +438,7 @@ public class GRDBSchemaMigrator {
|
||||
}
|
||||
|
||||
public static let grdbSchemaVersionDefault: UInt = 0
|
||||
public static let grdbSchemaVersionLatest: UInt = 137
|
||||
public static let grdbSchemaVersionLatest: UInt = 138
|
||||
|
||||
private class DatabaseMigratorWrapper {
|
||||
var migrator = DatabaseMigrator()
|
||||
@ -4952,6 +4953,15 @@ public class GRDBSchemaMigrator {
|
||||
return .success(())
|
||||
}
|
||||
|
||||
migrator.registerMigration(.addSession) { tx in
|
||||
try createSession(tx: tx)
|
||||
if BuildFlags.migrateDeprecatedSessions {
|
||||
try migrateSessions(tx: tx)
|
||||
}
|
||||
try dropOldSessions(tx: tx)
|
||||
return .success(())
|
||||
}
|
||||
|
||||
// MARK: - Schema Migration Insertion Point
|
||||
}
|
||||
|
||||
@ -7250,6 +7260,105 @@ public class GRDBSchemaMigrator {
|
||||
arguments: ["TSStorageUserAccountCollection"],
|
||||
) ?? false
|
||||
}
|
||||
|
||||
static func createSession(tx: DBWriteTransaction) throws {
|
||||
try tx.database.create(table: "Session") {
|
||||
$0.column("id", .integer).primaryKey().notNull()
|
||||
$0.column("recipientId", .integer).notNull()
|
||||
.references("model_SignalRecipient", column: "id", onDelete: .cascade, onUpdate: .cascade)
|
||||
$0.column("localIdentity", .integer).notNull()
|
||||
$0.column("deviceId", .integer).notNull()
|
||||
$0.column("serializedRecord", .blob)
|
||||
$0.check(sql: #"1 <= "deviceId" AND "deviceId" <= 127"#)
|
||||
}
|
||||
// For fetching session(s) for a recipient.
|
||||
try tx.database.create(
|
||||
index: "Session_Unique",
|
||||
on: "Session",
|
||||
columns: ["recipientId", "localIdentity", "deviceId"],
|
||||
options: [.unique],
|
||||
)
|
||||
}
|
||||
|
||||
static func migrateSessions(tx: DBWriteTransaction) throws {
|
||||
// If these ever change, you'll need to add a new migration to update the
|
||||
// Session table and replace the old constants with the new constants.
|
||||
assert(OWSIdentity.aci.rawValue == 0)
|
||||
assert(OWSIdentity.pni.rawValue == 1)
|
||||
|
||||
try migrateSessions(in: "TSStorageManagerSessionStoreCollection", identity: 0, tx: tx)
|
||||
try migrateSessions(in: "TSStorageManagerPNISessionStoreCollection", identity: 1, tx: tx)
|
||||
}
|
||||
|
||||
static func migrateSessions(
|
||||
in collection: String,
|
||||
identity: Int64,
|
||||
tx: DBWriteTransaction,
|
||||
) throws {
|
||||
let keys = try String.fetchAll(
|
||||
tx.database,
|
||||
sql: "SELECT key FROM keyvalue WHERE collection = ?",
|
||||
arguments: [collection],
|
||||
)
|
||||
for key in keys { try autoreleasepool {
|
||||
let dataValue = try Data.fetchOne(
|
||||
tx.database,
|
||||
sql: "SELECT value FROM keyvalue WHERE collection = ? AND key = ?",
|
||||
arguments: [collection, key],
|
||||
)!
|
||||
let sessionDictionary: [Int32: Data?]
|
||||
let decodedValue = try? NSKeyedUnarchiver.unarchivedObject(
|
||||
ofClasses: [NSDictionary.self, NSNumber.self, NSData.self],
|
||||
from: dataValue,
|
||||
) as? [Int32: Data]
|
||||
if let decodedValue {
|
||||
sessionDictionary = decodedValue
|
||||
} else {
|
||||
// We expect some failures (for legacy data), and if there are failures, we
|
||||
// want to remember that there was a session, even though we can't do
|
||||
// anything with that session. (See also `hasSessionRecords`).
|
||||
Logger.warn("Storing nil for \(key) in \(collection) that couldn't be decoded")
|
||||
sessionDictionary = [1: nil]
|
||||
}
|
||||
let recipientId = try Int64.fetchOne(
|
||||
tx.database,
|
||||
sql: "SELECT id FROM model_SignalRecipient WHERE uniqueId = ?",
|
||||
arguments: [key],
|
||||
)
|
||||
guard let recipientId else {
|
||||
// If we can't find the SignalRecipient, these sessions aren't reachable,
|
||||
// so we don't need to keep them. (Foreign key constraints will enforce
|
||||
// this moving forward.)
|
||||
Logger.warn("Skipping \(key) in \(collection) that's been orphaned")
|
||||
return
|
||||
}
|
||||
for (deviceId, serializedRecord) in sessionDictionary {
|
||||
guard deviceId >= 1 && deviceId <= 127 else {
|
||||
Logger.warn("Skipping \(deviceId) for \(key) in \(collection) that's not valid")
|
||||
continue
|
||||
}
|
||||
try tx.database.execute(
|
||||
sql: "INSERT INTO Session (recipientId, localIdentity, deviceId, serializedRecord) VALUES (?, ?, ?, ?)",
|
||||
arguments: [
|
||||
recipientId,
|
||||
identity,
|
||||
deviceId,
|
||||
serializedRecord,
|
||||
],
|
||||
)
|
||||
}
|
||||
}}
|
||||
}
|
||||
|
||||
static func dropOldSessions(tx: DBWriteTransaction) throws {
|
||||
let collections = [
|
||||
"TSStorageManagerSessionStoreCollection",
|
||||
"TSStorageManagerPNISessionStoreCollection",
|
||||
]
|
||||
for collection in collections {
|
||||
try tx.database.execute(sql: "DELETE FROM keyvalue WHERE collection = ?", arguments: [collection])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: -
|
||||
|
||||
@ -34,29 +34,50 @@ public final class RecipientIdFinder {
|
||||
guard let recipient = recipientDatabaseTable.fetchRecipient(serviceId: serviceId, transaction: tx) else {
|
||||
return nil
|
||||
}
|
||||
return recipientUniqueIdResult(for: serviceId, recipient: recipient)
|
||||
return validateRecipient(recipient, for: serviceId).map(\.uniqueId)
|
||||
}
|
||||
|
||||
public func recipientUniqueId(for address: SignalServiceAddress, tx: DBReadTransaction) -> Result<RecipientUniqueId, RecipientIdError>? {
|
||||
guard
|
||||
let recipient = DependenciesBridge.shared.recipientDatabaseTable
|
||||
.fetchRecipient(address: address, tx: tx)
|
||||
else {
|
||||
guard let recipient = recipientDatabaseTable.fetchRecipient(address: address, tx: tx) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return recipientUniqueIdResult(for: address.serviceId, recipient: recipient)
|
||||
return validateRecipient(recipient, for: address.serviceId).map(\.uniqueId)
|
||||
}
|
||||
|
||||
public func ensureRecipientUniqueId(for serviceId: ServiceId, tx: DBWriteTransaction) -> Result<RecipientUniqueId, RecipientIdError> {
|
||||
let recipient = recipientFetcher.fetchOrCreate(serviceId: serviceId, tx: tx)
|
||||
return recipientUniqueIdResult(for: serviceId, recipient: recipient)
|
||||
return ensureRecipient(for: serviceId, tx: tx).map(\.uniqueId)
|
||||
}
|
||||
|
||||
private func recipientUniqueIdResult(for serviceId: ServiceId?, recipient: SignalRecipient) -> Result<RecipientUniqueId, RecipientIdError> {
|
||||
public func recipientId(for serviceId: ServiceId, tx: DBReadTransaction) -> Result<SignalRecipient.RowId, RecipientIdError>? {
|
||||
guard let recipient = recipientDatabaseTable.fetchRecipient(serviceId: serviceId, transaction: tx) else {
|
||||
return nil
|
||||
}
|
||||
return validateRecipient(recipient, for: serviceId).map(\.id)
|
||||
}
|
||||
|
||||
public func recipientId(for address: SignalServiceAddress, tx: DBReadTransaction) -> Result<SignalRecipient.RowId, RecipientIdError>? {
|
||||
guard let recipient = recipientDatabaseTable.fetchRecipient(address: address, tx: tx) else {
|
||||
return nil
|
||||
}
|
||||
return validateRecipient(recipient, for: address.serviceId).map(\.id)
|
||||
}
|
||||
|
||||
public func ensureRecipientId(for serviceId: ServiceId, tx: DBWriteTransaction) -> Result<SignalRecipient.RowId, RecipientIdError> {
|
||||
return ensureRecipient(for: serviceId, tx: tx).map(\.id)
|
||||
}
|
||||
|
||||
public func ensureRecipient(for serviceId: ServiceId, tx: DBWriteTransaction) -> Result<SignalRecipient, RecipientIdError> {
|
||||
let recipient = recipientFetcher.fetchOrCreate(serviceId: serviceId, tx: tx)
|
||||
return validateRecipient(recipient, for: serviceId)
|
||||
}
|
||||
|
||||
private func validateRecipient(
|
||||
_ recipient: SignalRecipient,
|
||||
for serviceId: ServiceId?,
|
||||
) -> Result<SignalRecipient, RecipientIdError> {
|
||||
if serviceId is Pni, recipient.aciString != nil {
|
||||
return .failure(.mustNotUsePniBecauseAciExists)
|
||||
}
|
||||
return .success(recipient.uniqueId)
|
||||
return .success(recipient)
|
||||
}
|
||||
}
|
||||
|
||||
@ -266,12 +266,16 @@ struct FakeSignalClient: TestSignalClient {
|
||||
struct LocalSignalClient: TestSignalClient {
|
||||
let identity: OWSIdentity
|
||||
let _preKeyStore: PreKeyStore
|
||||
let _sessionStore: SSKSessionStore
|
||||
let _sessionStore: SessionManagerForIdentity
|
||||
|
||||
init(identity: OWSIdentity = .aci) {
|
||||
self.identity = identity
|
||||
self._preKeyStore = PreKeyStore()
|
||||
self._sessionStore = SSKSessionStore(for: identity, recipientIdFinder: DependenciesBridge.shared.recipientIdFinder)
|
||||
self._sessionStore = SessionManagerForIdentity(
|
||||
identity: identity,
|
||||
recipientIdFinder: DependenciesBridge.shared.recipientIdFinder,
|
||||
sessionStore: SessionStore(),
|
||||
)
|
||||
}
|
||||
|
||||
var identityKeyPair: ECKeyPair {
|
||||
|
||||
@ -45,14 +45,28 @@ final class PreKeyTaskTests: SSKBaseTest {
|
||||
mockAPIClient = .init()
|
||||
mockDateProvider = .init()
|
||||
mockDb = InMemoryDB()
|
||||
let sessionStore = SignalServiceKit.SessionStore()
|
||||
|
||||
mockPreKeyStore = PreKeyStore()
|
||||
mockAciProtocolStore = .mock(identity: .aci, preKeyStore: mockPreKeyStore)
|
||||
mockPniProtocolStore = .mock(identity: .pni, preKeyStore: mockPreKeyStore)
|
||||
mockAciProtocolStore = .build(
|
||||
dateProvider: mockDateProvider.targetDate,
|
||||
identity: .aci,
|
||||
preKeyStore: mockPreKeyStore,
|
||||
recipientIdFinder: recipientIdFinder,
|
||||
sessionStore: sessionStore,
|
||||
)
|
||||
mockPniProtocolStore = .build(
|
||||
dateProvider: mockDateProvider.targetDate,
|
||||
identity: .pni,
|
||||
preKeyStore: mockPreKeyStore,
|
||||
recipientIdFinder: recipientIdFinder,
|
||||
sessionStore: sessionStore,
|
||||
)
|
||||
mockProtocolStoreManager = SignalProtocolStoreManager(
|
||||
aciProtocolStore: mockAciProtocolStore,
|
||||
pniProtocolStore: mockPniProtocolStore,
|
||||
preKeyStore: mockPreKeyStore,
|
||||
sessionStore: sessionStore,
|
||||
)
|
||||
|
||||
taskManager = PreKeyTaskManager(
|
||||
|
||||
@ -28,16 +28,13 @@ private class MockStorageServiceManager: StorageServiceManager {
|
||||
}
|
||||
|
||||
private class TestDependencies {
|
||||
let aciSessionStore: SignalSessionStore
|
||||
var aciSessionStoreKeyValueStore: KeyValueStore {
|
||||
KeyValueStore(collection: "TSStorageManagerSessionStoreCollection")
|
||||
}
|
||||
let identityManager: MockIdentityManager
|
||||
let mockDB = InMemoryDB()
|
||||
let recipientMerger: RecipientMerger
|
||||
let recipientDatabaseTable = RecipientDatabaseTable()
|
||||
let recipientFetcher: RecipientFetcher
|
||||
let recipientIdFinder: RecipientIdFinder
|
||||
let sessionStore: SignalServiceKit.SessionStore
|
||||
let threadAssociatedDataStore: MockThreadAssociatedDataStore
|
||||
let threadStore: MockThreadStore
|
||||
let threadMerger: ThreadMerger
|
||||
@ -50,10 +47,10 @@ private class TestDependencies {
|
||||
searchableNameIndexer: searchableNameIndexer,
|
||||
)
|
||||
recipientIdFinder = RecipientIdFinder(recipientDatabaseTable: recipientDatabaseTable, recipientFetcher: recipientFetcher)
|
||||
aciSessionStore = SSKSessionStore(for: .aci, recipientIdFinder: recipientIdFinder)
|
||||
identityManager = MockIdentityManager(recipientIdFinder: recipientIdFinder)
|
||||
identityManager.recipientIdentities = [:]
|
||||
identityManager.sessionSwitchoverMessages = []
|
||||
sessionStore = SignalServiceKit.SessionStore()
|
||||
threadAssociatedDataStore = MockThreadAssociatedDataStore()
|
||||
threadStore = MockThreadStore()
|
||||
threadMerger = ThreadMerger.forUnitTests(
|
||||
@ -61,7 +58,6 @@ private class TestDependencies {
|
||||
threadStore: threadStore
|
||||
)
|
||||
recipientMerger = RecipientMergerImpl(
|
||||
aciSessionStore: aciSessionStore,
|
||||
blockedRecipientStore: BlockedRecipientStore(),
|
||||
identityManager: identityManager,
|
||||
observers: RecipientMergerImpl.Observers(
|
||||
@ -72,6 +68,7 @@ private class TestDependencies {
|
||||
recipientDatabaseTable: recipientDatabaseTable,
|
||||
recipientFetcher: recipientFetcher,
|
||||
searchableNameIndexer: searchableNameIndexer,
|
||||
sessionStore: sessionStore,
|
||||
storageServiceManager: storageServiceManager,
|
||||
storyRecipientStore: StoryRecipientStore()
|
||||
)
|
||||
@ -253,7 +250,13 @@ class RecipientMergerTest: XCTestCase {
|
||||
createdAt: Date(),
|
||||
verificationState: .default
|
||||
)
|
||||
d.aciSessionStoreKeyValueStore.setData(Data(), key: recipient.uniqueId, transaction: tx)
|
||||
d.sessionStore.upsertSession(
|
||||
forRecipientId: recipient.id,
|
||||
deviceId: .primary,
|
||||
localIdentity: .aci,
|
||||
recordData: Data(),
|
||||
tx: tx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -266,7 +269,7 @@ class RecipientMergerTest: XCTestCase {
|
||||
|
||||
XCTAssertEqual(d.identityManager.identityChangeInfoMessages, testCase.shouldInsertEvent ? [ac1] : [])
|
||||
XCTAssertEqual(try! d.identityManager.identityKey(for: ac1, tx: tx), ik1)
|
||||
XCTAssertTrue(d.aciSessionStore.mightContainSession(for: mergedRecipient, tx: tx))
|
||||
XCTAssertTrue(d.sessionStore.hasSessionRecords(forRecipientId: mergedRecipient.id, localIdentity: .aci, tx: tx))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -402,7 +405,13 @@ class RecipientMergerTest: XCTestCase {
|
||||
d.mockDB.write { tx in
|
||||
for recipientNumber in testCase.hasSession {
|
||||
let recipient = recipients.dropFirst(recipientNumber - 1).first!
|
||||
d.aciSessionStoreKeyValueStore.setData(Data(), key: recipient.uniqueId, transaction: tx)
|
||||
d.sessionStore.upsertSession(
|
||||
forRecipientId: recipient.id,
|
||||
deviceId: .primary,
|
||||
localIdentity: .aci,
|
||||
recordData: Data(),
|
||||
tx: tx,
|
||||
)
|
||||
let thread = TSContactThread(contactAddress: SignalServiceAddress(
|
||||
serviceId: recipient.aci ?? recipient.pni,
|
||||
phoneNumber: recipient.phoneNumber?.stringValue,
|
||||
@ -444,10 +453,22 @@ class RecipientMergerTest: XCTestCase {
|
||||
let d = TestDependencies()
|
||||
let aciRecipient = d.mockDB.write { tx in
|
||||
let aciRecipient = try! SignalRecipient.insertRecord(aci: aci, tx: tx)
|
||||
d.aciSessionStoreKeyValueStore.setData(Data(), key: aciRecipient.uniqueId, transaction: tx)
|
||||
d.sessionStore.upsertSession(
|
||||
forRecipientId: aciRecipient.id,
|
||||
deviceId: .primary,
|
||||
localIdentity: .aci,
|
||||
recordData: Data(),
|
||||
tx: tx,
|
||||
)
|
||||
|
||||
let pniRecipient = try! SignalRecipient.insertRecord(phoneNumber: phoneNumber, pni: pni, tx: tx)
|
||||
d.aciSessionStoreKeyValueStore.setData(Data(), key: pniRecipient.uniqueId, transaction: tx)
|
||||
d.sessionStore.upsertSession(
|
||||
forRecipientId: pniRecipient.id,
|
||||
deviceId: .primary,
|
||||
localIdentity: .aci,
|
||||
recordData: Data(),
|
||||
tx: tx,
|
||||
)
|
||||
|
||||
return aciRecipient
|
||||
}
|
||||
|
||||
@ -8,58 +8,6 @@ import LibSignalClient
|
||||
|
||||
@testable import SignalServiceKit
|
||||
|
||||
class SessionStoreTest: SSKBaseTest {
|
||||
func testLegacySessionIsDropped() {
|
||||
@objc(FakeLegacySession) class FakeLegacySession: NSObject, NSCoding {
|
||||
override init() {
|
||||
}
|
||||
|
||||
required init?(coder: NSCoder) {
|
||||
fatalError("should never be deserialized")
|
||||
}
|
||||
|
||||
func encode(with coder: NSCoder) {
|
||||
// no properties
|
||||
}
|
||||
}
|
||||
|
||||
// We have to use the database-based KeyValueStore to test this
|
||||
// because the in-memory one skips the archiving step.
|
||||
let sessionStore = SSKSessionStore(for: .aci, recipientIdFinder: DependenciesBridge.shared.recipientIdFinder)
|
||||
let recipient = write {
|
||||
DependenciesBridge.shared.recipientFetcher.fetchOrCreate(serviceId: Aci.randomForTesting(), tx: $0)
|
||||
}
|
||||
|
||||
// First make sure that if we write a "valid" session, it can be read.
|
||||
let singleValidSessionData = try! NSKeyedArchiver.archivedData(withRootObject: [1: Data()], requiringSecureCoding: true)
|
||||
write {
|
||||
sessionStore.keyValueStoreForTesting.setData(singleValidSessionData, key: recipient.uniqueId, transaction: $0)
|
||||
}
|
||||
|
||||
read {
|
||||
XCTAssertTrue(sessionStore.mightContainSession(for: recipient, tx: $0))
|
||||
XCTAssertNotNil(try! sessionStore.loadSession(for: recipient.aci!, deviceId: DeviceId(validating: 1)!, tx: $0))
|
||||
}
|
||||
|
||||
// Then imitate a session store with a mix of legacy and modern sessions.
|
||||
let sessions: NSDictionary = [1: FakeLegacySession(), 2: Data()]
|
||||
let archiver = NSKeyedArchiver(requiringSecureCoding: false)
|
||||
archiver.setClassName("SSKLegacySessionClassThatNoLongerExists", for: FakeLegacySession.self)
|
||||
archiver.encode(sessions, forKey: NSKeyedArchiveRootObjectKey)
|
||||
|
||||
write {
|
||||
sessionStore.keyValueStoreForTesting.setData(archiver.encodedData, key: recipient.uniqueId, transaction: $0)
|
||||
}
|
||||
|
||||
read {
|
||||
// There's something in the store...
|
||||
XCTAssertTrue(sessionStore.mightContainSession(for: recipient, tx: $0))
|
||||
// ...but it turns into nil on load.
|
||||
XCTAssertNil(try! sessionStore.loadSession(for: recipient.aci!, deviceId: DeviceId(validating: 2)!, tx: $0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class SessionStoreTest2: XCTestCase {
|
||||
func testMaxUnacknowledgedSessionAge() throws {
|
||||
let bob_address = try ProtocolAddress(name: "+14155550100", deviceId: 1)
|
||||
|
||||
@ -1160,4 +1160,118 @@ class GRDBSchemaMigratorTest: XCTestCase {
|
||||
XCTAssertEqual(callLinks[2][0] as String?, "Something")
|
||||
}
|
||||
}
|
||||
|
||||
private func keyedArchiverSessionData(deviceIds: [Int32]) -> Data {
|
||||
let sessionDictionary = Dictionary(uniqueKeysWithValues: deviceIds.map { ($0, Data()) })
|
||||
return keyedArchiverData(rootObject: sessionDictionary)
|
||||
}
|
||||
|
||||
func testMigrateSessions() throws {
|
||||
let databaseQueue = DatabaseQueue()
|
||||
try databaseQueue.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE keyvalue (collection TEXT NOT NULL, key TEXT NOT NULL, value BLOB NOT NULL);
|
||||
""")
|
||||
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE model_SignalRecipient (id INTEGER PRIMARY KEY, uniqueId TEXT NOT NULL);
|
||||
""")
|
||||
|
||||
let recipient1UniqueId = UUID().uuidString
|
||||
let recipient2UniqueId = UUID().uuidString
|
||||
let recipient3UniqueId = UUID().uuidString
|
||||
let recipient4UniqueId = UUID().uuidString
|
||||
|
||||
try db.execute(
|
||||
sql: "INSERT INTO model_SignalRecipient (id, uniqueId) VALUES (?, ?)",
|
||||
arguments: [1, recipient1UniqueId],
|
||||
)
|
||||
try db.execute(
|
||||
sql: "INSERT INTO model_SignalRecipient (id, uniqueId) VALUES (?, ?)",
|
||||
arguments: [2, recipient2UniqueId],
|
||||
)
|
||||
// Don't insert recipient3UniqueKey.
|
||||
try db.execute(
|
||||
sql: "INSERT INTO model_SignalRecipient (id, uniqueId) VALUES (?, ?)",
|
||||
arguments: [4, recipient4UniqueId],
|
||||
)
|
||||
|
||||
try db.execute(
|
||||
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
|
||||
arguments: ["TSStorageManagerSessionStoreCollection", recipient1UniqueId, keyedArchiverSessionData(deviceIds: [1, 2, 128])],
|
||||
)
|
||||
try db.execute(
|
||||
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
|
||||
arguments: ["TSStorageManagerPNISessionStoreCollection", recipient1UniqueId, keyedArchiverSessionData(deviceIds: [0, 2, 3])],
|
||||
)
|
||||
try db.execute(
|
||||
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
|
||||
arguments: ["TSStorageManagerSessionStoreCollection", recipient2UniqueId, keyedArchiverSessionData(deviceIds: [1])],
|
||||
)
|
||||
try db.execute(
|
||||
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
|
||||
arguments: ["TSStorageManagerSessionStoreCollection", recipient3UniqueId, keyedArchiverSessionData(deviceIds: [1])],
|
||||
)
|
||||
|
||||
@objc(FakeLegacySession) class FakeLegacySession: NSObject, NSCoding {
|
||||
override init() {}
|
||||
required init?(coder: NSCoder) { fatalError("should never be deserialized") }
|
||||
func encode(with coder: NSCoder) {}
|
||||
}
|
||||
let legacyArchivedData: Data
|
||||
do {
|
||||
let sessionDictionary: [Int32: AnyObject] = [1: FakeLegacySession(), 2: Data() as NSData]
|
||||
let archiver = NSKeyedArchiver(requiringSecureCoding: false)
|
||||
archiver.setClassName("SSKLegacySessionClassThatNoLongerExists", for: FakeLegacySession.self)
|
||||
archiver.encode(sessionDictionary, forKey: NSKeyedArchiveRootObjectKey)
|
||||
legacyArchivedData = archiver.encodedData
|
||||
}
|
||||
try db.execute(
|
||||
sql: "INSERT INTO keyvalue (collection, key, value) VALUES (?, ?, ?)",
|
||||
arguments: ["TSStorageManagerSessionStoreCollection", recipient4UniqueId, legacyArchivedData],
|
||||
)
|
||||
|
||||
do {
|
||||
let tx = DBWriteTransaction(database: db)
|
||||
defer { tx.finalizeTransaction() }
|
||||
try GRDBSchemaMigrator.createSession(tx: tx)
|
||||
try GRDBSchemaMigrator.migrateSessions(tx: tx)
|
||||
try GRDBSchemaMigrator.dropOldSessions(tx: tx)
|
||||
}
|
||||
|
||||
let sessions = try Row.fetchAll(db, sql: "SELECT * FROM Session ORDER BY recipientId, localIdentity, deviceId")
|
||||
|
||||
XCTAssertEqual(sessions.count, 6)
|
||||
|
||||
XCTAssertEqual(sessions[0]["recipientId"] as Int64, 1)
|
||||
XCTAssertEqual(sessions[0]["localIdentity"] as Int64, 0)
|
||||
XCTAssertEqual(sessions[0]["deviceId"] as Int8, 1)
|
||||
XCTAssertEqual(sessions[0]["serializedRecord"] as Data?, Data())
|
||||
|
||||
XCTAssertEqual(sessions[1]["recipientId"] as Int64, 1)
|
||||
XCTAssertEqual(sessions[1]["localIdentity"] as Int64, 0)
|
||||
XCTAssertEqual(sessions[1]["deviceId"] as Int8, 2)
|
||||
XCTAssertEqual(sessions[1]["serializedRecord"] as Data?, Data())
|
||||
|
||||
XCTAssertEqual(sessions[2]["recipientId"] as Int64, 1)
|
||||
XCTAssertEqual(sessions[2]["localIdentity"] as Int64, 1)
|
||||
XCTAssertEqual(sessions[2]["deviceId"] as Int8, 2)
|
||||
XCTAssertEqual(sessions[2]["serializedRecord"] as Data?, Data())
|
||||
|
||||
XCTAssertEqual(sessions[3]["recipientId"] as Int64, 1)
|
||||
XCTAssertEqual(sessions[3]["localIdentity"] as Int64, 1)
|
||||
XCTAssertEqual(sessions[3]["deviceId"] as Int8, 3)
|
||||
XCTAssertEqual(sessions[3]["serializedRecord"] as Data?, Data())
|
||||
|
||||
XCTAssertEqual(sessions[4]["recipientId"] as Int64, 2)
|
||||
XCTAssertEqual(sessions[4]["localIdentity"] as Int64, 0)
|
||||
XCTAssertEqual(sessions[4]["deviceId"] as Int8, 1)
|
||||
XCTAssertEqual(sessions[4]["serializedRecord"] as Data?, Data())
|
||||
|
||||
XCTAssertEqual(sessions[5]["recipientId"] as Int64, 4)
|
||||
XCTAssertEqual(sessions[5]["localIdentity"] as Int64, 0)
|
||||
XCTAssertEqual(sessions[5]["deviceId"] as Int8, 1)
|
||||
XCTAssertEqual(sessions[5]["serializedRecord"] as Data?, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user