//
// Copyright 2025 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import GRDB
import LibSignalClient

/// Entry point for pre-key persistence.
///
/// Owns one `PreKeyStoreForIdentity` per `OWSIdentity` and provides the
/// table-wide operations: bulk removal, ID allocation, marking keys as
/// replaced, and culling replaced keys.
struct PreKeyStore {
    enum Error: Swift.Error {
        /// No record exists for the requested key ID in the queried namespace.
        case noPreKeyWithId(UInt32)
    }

    let aciStore: PreKeyStoreForIdentity
    let pniStore: PreKeyStoreForIdentity

    init() {
        self.aciStore = PreKeyStoreForIdentity(identity: .aci)
        self.pniStore = PreKeyStoreForIdentity(identity: .pni)
    }

    /// Returns the store scoped to `identity`.
    func forIdentity(_ identity: OWSIdentity) -> PreKeyStoreForIdentity {
        switch identity {
        case .aci: aciStore
        case .pni: pniStore
        }
    }

    /// Deletes every row in the pre-key table, across all identities and namespaces.
    func removeAll(tx: DBWriteTransaction) {
        Logger.info("")
        failIfThrows {
            _ = try PreKeyRecord.deleteAll(tx.database)
        }
    }

    /// Reserves `count` consecutive pre-key IDs.
    ///
    /// Reads the last allocated ID from `metadataStore` under `lastPreKeyIdKey`.
    /// A stored value that does not fit in `UInt32` is treated as absent. It then
    /// asks `PreKeyId.nextPreKeyIds` for the next range and persists that range's
    /// upper bound, so the next call continues from there.
    ///
    /// - Returns: The newly reserved key IDs.
    func allocatePreKeyIds(
        in metadataStore: KeyValueStore,
        lastPreKeyIdKey: String,
        count: Int,
        tx: DBWriteTransaction,
    ) -> ClosedRange<UInt32> {
        let lastPreKeyId = metadataStore.getInt(lastPreKeyIdKey, transaction: tx).flatMap(UInt32.init(exactly:))
        let preKeyIds = PreKeyId.nextPreKeyIds(lastPreKeyId: lastPreKeyId, count: count)
        metadataStore.setInt(Int(preKeyIds.upperBound), key: lastPreKeyIdKey, transaction: tx)
        return preKeyIds
    }

    /// Stamps `replacedAt` on pre-keys that have not been marked replaced yet.
    ///
    /// A row is updated only if all of the following hold:
    /// - it is in `namespace` for `identity`;
    /// - its `isOneTime` matches;
    /// - its `replacedAt` is currently NULL;
    /// - its key ID is not in `exceptForPreKeyIds`.
    ///
    /// The stored value is `now` as whole seconds since 1970, truncated.
    func setReplacedAtIfNil(
        to now: Date,
        in namespace: PreKeyRecord.Namespace,
        identity: OWSIdentity,
        isOneTime: Bool,
        exceptFor exceptForPreKeyIds: [UInt32],
        tx: DBWriteTransaction,
    ) {
        let keyIdColumn = Column(PreKeyRecord.CodingKeys.keyId.rawValue)
        let replacedAtColumn = Column(PreKeyRecord.CodingKeys.replacedAt.rawValue)
        let isOneTimeColumn = Column(PreKeyRecord.CodingKeys.isOneTime.rawValue)
        failIfThrows {
            _ = try PreKeyRecord.baseQuery(in: namespace, identity: identity)
                .filter(isOneTimeColumn == isOneTime)
                .filter(replacedAtColumn == nil)
                .filter(!exceptForPreKeyIds.contains(keyIdColumn))
                .updateAll(tx.database, [replacedAtColumn.set(to: Int64(now.timeIntervalSince1970))])
        }
    }

    /// Deletes pre-keys whose `replacedAt` falls outside the retention window.
    ///
    /// The window is `now ± (PreKeyManagerImpl.Constants.maxUnacknowledgedSessionAge + gracePeriod)`.
    /// The query is not scoped by namespace or identity, so it covers every pre-key.
    /// Rows whose `replacedAt` is NULL (never marked replaced) never match, because
    /// SQL comparisons against NULL are never true.
    func cullPreKeys(gracePeriod: TimeInterval, tx: DBWriteTransaction) {
        let now = Date().timeIntervalSince1970
        let delay = PreKeyManagerImpl.Constants.maxUnacknowledgedSessionAge + gracePeriod
        let replacedAt = Column(PreKeyRecord.CodingKeys.replacedAt.rawValue)
        failIfThrows {
            // Keys stamped too far in the *future* are removed as well.
            // NOTE(review): presumably this handles a wall clock that moved backwards; confirm.
            let query = PreKeyRecord.filter(replacedAt < Int64(now - delay) || replacedAt > Int64(now + delay))

            // Row IDs are collected first; deletion happens only after the cursor
            // has been fully consumed, not while it is still iterating.
            var rowIds = [Int64]()
            let cursor = try query.fetchCursor(tx.database)
            while let preKey = try cursor.next() {
                // The filter above excludes NULL, but render defensively rather than force-unwrapping.
                let replacedAtDescription = preKey.replacedAt.map { "\($0)" } ?? "nil"
                Logger.info("removing prekey \(preKey.namespace) \(preKey.keyId), replacedAt \(replacedAtDescription)")
                rowIds.append(preKey.rowId)
            }
            for rowId in rowIds {
                try PreKeyRecord.deleteOne(tx.database, key: rowId)
            }
        }
    }
}

/// Pre-key storage scoped to a single `OWSIdentity`.
///
/// Every query is built from `PreKeyRecord.baseQuery(in:identity:)` with this
/// store's identity, and every insert records that identity.
class PreKeyStoreForIdentity {
    private let identity: OWSIdentity

    init(identity: OWSIdentity) {
        self.identity = identity
    }

    /// Query for rows in `namespace` that belong to this store's identity.
    private func baseQuery(in namespace: PreKeyRecord.Namespace) -> QueryInterfaceRequest<PreKeyRecord> {
        return PreKeyRecord.baseQuery(in: namespace, identity: self.identity)
    }

    /// Returns the record for `keyId` in `namespace`, or `nil` if there is none.
    ///
    /// Database errors are not thrown to the caller. They are converted with
    /// `grdbErrorForLogging` and handed to `failIfThrows`.
    func fetchPreKey(in namespace: PreKeyRecord.Namespace, for keyId: UInt32, tx: DBReadTransaction) -> PreKeyRecord? {
        failIfThrows {
            do {
                return try baseQuery(in: namespace)
                    .filter(Column(PreKeyRecord.CodingKeys.keyId.rawValue) == keyId)
                    .fetchOne(tx.database)
            } catch {
                throw error.grdbErrorForLogging
            }
        }
    }

    /// Returns the serialized LibSignal record bytes for `keyId` in `namespace`.
    ///
    /// - Throws: `PreKeyStore.Error.noPreKeyWithId` if there is no such record.
    private func fetchSerializedRecord(in namespace: PreKeyRecord.Namespace, for keyId: UInt32, tx: DBReadTransaction) throws -> Data {
        let preKey = fetchPreKey(in: namespace, for: keyId, tx: tx)
        guard let serializedRecord = preKey?.serializedRecord else {
            throw PreKeyStore.Error.noPreKeyWithId(keyId)
        }
        return serializedRecord
    }

    /// Inserts a pre-key row, replacing any existing row it conflicts with.
    ///
    /// Only namespace, identity, key ID, `isOneTime`, and the serialized record
    /// are written. `replacedAt` is not, so a new or replacing row takes that
    /// column's default. NOTE(review): presumably NULL; confirm against the schema.
    func upsertPreKeyRecord(
        _ serializedRecord: Data,
        keyId: UInt32,
        in namespace: PreKeyRecord.Namespace,
        isOneTime: Bool,
        tx: DBWriteTransaction,
    ) {
        failIfThrows {
            do {
                // Key IDs intentionally aren't large enough to avoid conflicts when
                // sampling randomly. Clients don't generate conflicting keys, though
                // certain operations (e.g., change number) may produce harmless conflicts.
                // We use "OR REPLACE" to keep the latest key if such a conflict occurs.
                try tx.database.execute(
                    sql: """
                    INSERT OR REPLACE INTO \(PreKeyRecord.databaseTableName) (
                        \(PreKeyRecord.CodingKeys.namespace.rawValue),
                        \(PreKeyRecord.CodingKeys.identity.rawValue),
                        \(PreKeyRecord.CodingKeys.keyId.rawValue),
                        \(PreKeyRecord.CodingKeys.isOneTime.rawValue),
                        \(PreKeyRecord.CodingKeys.serializedRecord.rawValue)
                    ) VALUES (?, ?, ?, ?, ?)
                    """,
                    arguments: [namespace.rawValue, self.identity.rawValue, keyId, isOneTime, serializedRecord],
                )
            } catch {
                throw error.grdbErrorForLogging
            }
        }
    }

    /// Deletes the row for `keyId` in `namespace` for this identity, if one exists.
    func removePreKey(in namespace: PreKeyRecord.Namespace, keyId: UInt32, tx: DBWriteTransaction) {
        let keyIdColumn = Column(PreKeyRecord.CodingKeys.keyId.rawValue)
        failIfThrows {
            _ = try baseQuery(in: namespace).filter(keyIdColumn == keyId).deleteAll(tx.database)
        }
    }

    #if TESTABLE_BUILD
    /// Counts rows in `namespace` for this identity with the given `isOneTime` flag.
    func fetchCount(in namespace: PreKeyRecord.Namespace, isOneTime: Bool, tx: DBReadTransaction) throws -> Int {
        return try baseQuery(in: namespace)
            .filter(Column(PreKeyRecord.CodingKeys.isOneTime.rawValue) == isOneTime)
            .fetchCount(tx.database)
    }
    #endif
}

extension PreKeyStoreForIdentity: LibSignalClient.PreKeyStore {
    /// Loads a one-time EC pre-key from the `.oneTime` namespace.
    func loadPreKey(id: UInt32, context: any StoreContext) throws -> LibSignalClient.PreKeyRecord {
        return try LibSignalClient.PreKeyRecord(bytes: fetchSerializedRecord(in: .oneTime, for: id, tx: context.asTransaction))
    }

    func removePreKey(id: UInt32, context: any StoreContext) throws {
        removePreKey(in: .oneTime, keyId: id, tx: context.asTransaction)
    }

    func storePreKey(_ record: LibSignalClient.PreKeyRecord, id: UInt32, context: any StoreContext) throws {
        // This is currently unused (and needs `replacedAt` support).
        owsFail("Not supported.")
    }
}

extension PreKeyStoreForIdentity: LibSignalClient.SignedPreKeyStore {
    /// Loads a signed pre-key from the `.signed` namespace.
    func loadSignedPreKey(id: UInt32, context: any StoreContext) throws -> LibSignalClient.SignedPreKeyRecord {
        return try LibSignalClient.SignedPreKeyRecord(bytes: fetchSerializedRecord(in: .signed, for: id, tx: context.asTransaction))
    }

    func storeSignedPreKey(_ record: LibSignalClient.SignedPreKeyRecord, id: UInt32, context: any StoreContext) throws {
        // This is currently unused (and needs `replacedAt` support).
        owsFail("Not supported.")
    }
}

extension PreKeyStoreForIdentity: LibSignalClient.KyberPreKeyStore {
    /// Loads a Kyber pre-key from the `.kyber` namespace.
    func loadKyberPreKey(id: UInt32, context: any StoreContext) throws -> LibSignalClient.KyberPreKeyRecord {
        return try LibSignalClient.KyberPreKeyRecord(bytes: fetchSerializedRecord(in: .kyber, for: id, tx: context.asTransaction))
    }

    /// Records that a Kyber pre-key was used.
    ///
    /// - One-time keys are deleted.
    /// - Non-one-time keys are kept, and a `KyberPreKeyUseRecord` row is inserted
    ///   for (kyber row, signed pre-key identity/ID, base key).
    ///
    /// - Throws: `PreKeyStore.Error.noPreKeyWithId` if the key doesn't exist. Also
    ///   rethrows a `SQLITE_CONSTRAINT` failure from that insert. NOTE(review):
    ///   presumably this signals a repeated use of the same combination; confirm
    ///   against the table's unique constraint. Any other database error goes to
    ///   `failIfThrows`.
    func markKyberPreKeyUsed(id keyId: UInt32, signedPreKeyId: UInt32, baseKey: PublicKey, context: any StoreContext) throws {
        let tx = context.asTransaction
        guard let preKey = fetchPreKey(in: .kyber, for: keyId, tx: tx) else {
            throw PreKeyStore.Error.noPreKeyWithId(keyId)
        }
        if preKey.isOneTime {
            removePreKey(in: .kyber, keyId: keyId, tx: tx)
        } else {
            do {
                try KyberPreKeyUseRecord(
                    kyberRowId: preKey.rowId,
                    signedPreKeyIdentity: self.identity,
                    signedPreKeyId: signedPreKeyId,
                    baseKey: baseKey.serialize(),
                ).insert(tx.database)
            } catch {
                let error = error.grdbErrorForLogging
                switch error {
                case DatabaseError.SQLITE_CONSTRAINT:
                    throw error
                default:
                    failIfThrows { throw error }
                }
            }
        }
    }

    func storeKyberPreKey(_ record: LibSignalClient.KyberPreKeyRecord, id: UInt32, context: any StoreContext) throws {
        // This is currently unused and can't be implemented properly.
        owsFail("Not supported.")
    }
}

#if TESTABLE_BUILD

/// Test-only write API that carries a `replacedAt` timestamp alongside the record.
protocol WritablePreKeyStore {
    func storePreKey(_ record: LibSignalClient.PreKeyRecord, replacedAt: Date?, context: any StoreContext) throws
}

extension WritablePreKeyStore where Self: LibSignalClient.PreKeyStore {
    /// Default for other LibSignal stores: forwards to `storePreKey(_:id:context:)`.
    /// `replacedAt` is ignored.
    func storePreKey(_ record: LibSignalClient.PreKeyRecord, replacedAt: Date?, context: any StoreContext) throws {
        try storePreKey(record, id: record.id, context: context)
    }
}

extension PreKeyStoreForIdentity: WritablePreKeyStore {
    func storePreKey(_ record: LibSignalClient.PreKeyRecord, replacedAt: Date?, context: any StoreContext) throws {
        // `upsertPreKeyRecord` cannot write `replacedAt`; refuse rather than silently drop it.
        owsPrecondition(replacedAt == nil)
        upsertPreKeyRecord(record.serialize(), keyId: record.id, in: .oneTime, isOneTime: true, tx: context.asTransaction)
    }
}

/// Test-only write API for signed pre-keys that carries a `replacedAt` timestamp.
protocol WritableSignedPreKeyStore {
    func storeSignedPreKey(_ record: LibSignalClient.SignedPreKeyRecord, replacedAt: Date?, context: any StoreContext) throws
}

extension WritableSignedPreKeyStore where Self: LibSignalClient.SignedPreKeyStore {
    /// Default for other LibSignal stores: forwards to `storeSignedPreKey(_:id:context:)`.
    /// `replacedAt` is ignored, matching the other writable-store defaults.
    func storeSignedPreKey(_ record: LibSignalClient.SignedPreKeyRecord, replacedAt: Date?, context: any StoreContext) throws {
        try storeSignedPreKey(record, id: record.id, context: context)
    }
}

extension PreKeyStoreForIdentity: WritableSignedPreKeyStore {
    func storeSignedPreKey(_ record: LibSignalClient.SignedPreKeyRecord, replacedAt: Date?, context: any StoreContext) throws {
        // `upsertPreKeyRecord` cannot write `replacedAt`; refuse rather than silently drop it.
        owsPrecondition(replacedAt == nil)
        upsertPreKeyRecord(record.serialize(), keyId: record.id, in: .signed, isOneTime: false, tx: context.asTransaction)
    }
}

/// Test-only write API for Kyber pre-keys that carries `isOneTime` and `replacedAt`.
protocol WritableKyberPreKeyStore {
    func storeKyberPreKey(_ record: LibSignalClient.KyberPreKeyRecord, isOneTime: Bool, replacedAt: Date?, context: any StoreContext) throws
}

extension WritableKyberPreKeyStore where Self: LibSignalClient.KyberPreKeyStore {
    /// Default for other LibSignal stores: forwards to `storeKyberPreKey(_:id:context:)`.
    /// `isOneTime` and `replacedAt` are ignored.
    func storeKyberPreKey(_ record: LibSignalClient.KyberPreKeyRecord, isOneTime: Bool, replacedAt: Date?, context: any StoreContext) throws {
        try storeKyberPreKey(record, id: record.id, context: context)
    }
}

extension PreKeyStoreForIdentity: WritableKyberPreKeyStore {
    func storeKyberPreKey(_ record: LibSignalClient.KyberPreKeyRecord, isOneTime: Bool, replacedAt: Date?, context: any StoreContext) throws {
        // `upsertPreKeyRecord` cannot write `replacedAt`; refuse rather than silently drop it.
        owsPrecondition(replacedAt == nil)
        upsertPreKeyRecord(record.serialize(), keyId: record.id, in: .kyber, isOneTime: isOneTime, tx: context.asTransaction)
    }
}

#endif