286 lines
11 KiB
Swift
286 lines
11 KiB
Swift
//
|
|
// Copyright 2025 Signal Messenger, LLC
|
|
// SPDX-License-Identifier: AGPL-3.0-only
|
|
//
|
|
|
|
import Foundation
|
|
import GRDB
|
|
import LibSignalClient
|
|
|
|
/// Entry point for persisted pre-keys. Holds one `PreKeyStoreForIdentity` per
/// identity (ACI and PNI) and provides operations that act on the whole table.
struct PreKeyStore {
    enum Error: Swift.Error {
        /// No pre-key with the given ID exists in the requested namespace for
        /// the requesting identity.
        case noPreKeyWithId(UInt32)
    }

    let aciStore: PreKeyStoreForIdentity
    let pniStore: PreKeyStoreForIdentity

    init() {
        self.aciStore = PreKeyStoreForIdentity(identity: .aci)
        self.pniStore = PreKeyStoreForIdentity(identity: .pni)
    }

    /// Returns the store scoped to `identity`.
    func forIdentity(_ identity: OWSIdentity) -> PreKeyStoreForIdentity {
        switch identity {
        case .aci: aciStore
        case .pni: pniStore
        }
    }

    /// Deletes every pre-key row, for all identities and all namespaces.
    func removeAll(tx: DBWriteTransaction) {
        Logger.info("")
        failIfThrows {
            _ = try PreKeyRecord.deleteAll(tx.database)
        }
    }

    /// Reserves `count` pre-key IDs following the last ID recorded in
    /// `metadataStore` under `lastPreKeyIdKey`, then records the new upper
    /// bound under the same key so the next call continues from there.
    ///
    /// A stored value that does not fit exactly in a `UInt32` is treated as
    /// if no value had been stored.
    ///
    /// - Returns: The reserved IDs, as computed by `PreKeyId.nextPreKeyIds`.
    func allocatePreKeyIds(
        in metadataStore: KeyValueStore,
        lastPreKeyIdKey: String,
        count: Int,
        tx: DBWriteTransaction,
    ) -> ClosedRange<UInt32> {
        let lastPreKeyId = metadataStore.getInt(lastPreKeyIdKey, transaction: tx).flatMap(UInt32.init(exactly:))
        let preKeyIds = PreKeyId.nextPreKeyIds(lastPreKeyId: lastPreKeyId, count: count)
        metadataStore.setInt(Int(preKeyIds.upperBound), key: lastPreKeyIdKey, transaction: tx)
        return preKeyIds
    }

    /// Stamps `now` as `replacedAt` on matching keys that don't have one yet.
    ///
    /// A key matches when it is in `namespace` for `identity`, has the given
    /// `isOneTime` flag, has a `nil` `replacedAt`, and is not listed in
    /// `exceptForPreKeyIds`. The timestamp is stored as whole seconds since
    /// 1970 (fractional seconds are truncated by the `Int64` conversion).
    func setReplacedAtIfNil(
        to now: Date,
        in namespace: PreKeyRecord.Namespace,
        identity: OWSIdentity,
        isOneTime: Bool,
        exceptFor exceptForPreKeyIds: [UInt32],
        tx: DBWriteTransaction,
    ) {
        let keyIdColumn = Column(PreKeyRecord.CodingKeys.keyId.rawValue)
        let replacedAtColumn = Column(PreKeyRecord.CodingKeys.replacedAt.rawValue)
        let isOneTimeColumn = Column(PreKeyRecord.CodingKeys.isOneTime.rawValue)
        failIfThrows {
            _ = try PreKeyRecord.baseQuery(in: namespace, identity: identity)
                .filter(isOneTimeColumn == isOneTime)
                .filter(replacedAtColumn == nil)
                .filter(!exceptForPreKeyIds.contains(keyIdColumn))
                .updateAll(tx.database, [replacedAtColumn.set(to: Int64(now.timeIntervalSince1970))])
        }
    }

    /// Deletes pre-keys, across every identity and namespace, whose
    /// `replacedAt` lies more than
    /// `PreKeyManagerImpl.Constants.maxUnacknowledgedSessionAge + gracePeriod`
    /// away from `now`, in either direction.
    ///
    /// Rows whose `replacedAt` is NULL are never matched, because SQL
    /// comparisons against NULL are not true. Each deleted key is logged first.
    ///
    /// - Parameters:
    ///   - gracePeriod: Extra time added to the maximum unacknowledged-session age.
    ///   - now: Reference time for the window; defaults to the current time
    ///     (overridable so callers such as tests can pin the clock).
    ///   - tx: Write transaction in which to delete.
    func cullPreKeys(gracePeriod: TimeInterval, now: Date = Date(), tx: DBWriteTransaction) {
        let nowSeconds = now.timeIntervalSince1970
        let delay = PreKeyManagerImpl.Constants.maxUnacknowledgedSessionAge + gracePeriod
        let replacedAt = Column(PreKeyRecord.CodingKeys.replacedAt.rawValue)
        // Keys "replaced in the future" are culled as well; presumably this
        // recovers from wall-clock changes. NOTE(review): confirm intent.
        let expired = PreKeyRecord.filter(
            replacedAt < Int64(nowSeconds - delay) || replacedAt > Int64(nowSeconds + delay),
        )
        failIfThrows {
            // Scope the cursor so it is fully consumed and released before
            // the table is modified.
            do {
                let cursor = try expired.fetchCursor(tx.database)
                while let preKey = try cursor.next() {
                    // The filter guarantees a non-NULL replacedAt; avoid `!` anyway.
                    let replacedAtDescription = preKey.replacedAt.map { "\($0)" } ?? "nil"
                    Logger.info("removing prekey \(preKey.namespace) \(preKey.keyId), replacedAt \(replacedAtDescription)")
                }
            }
            // Same filter and same transaction as the logging pass above, so
            // one DELETE removes exactly the rows that were just logged.
            _ = try expired.deleteAll(tx.database)
        }
    }
}
|
|
|
|
/// Reads and writes pre-key rows that belong to a single `OWSIdentity`.
class PreKeyStoreForIdentity {
    private let identity: OWSIdentity

    init(identity: OWSIdentity) {
        self.identity = identity
    }

    /// Request for rows in `namespace` that belong to this store's identity.
    private func baseQuery(in namespace: PreKeyRecord.Namespace) -> QueryInterfaceRequest<PreKeyRecord> {
        return PreKeyRecord.baseQuery(in: namespace, identity: self.identity)
    }

    /// Fetches the key with `keyId` in `namespace`, or `nil` if there is none.
    ///
    /// Database errors are mapped through `grdbErrorForLogging` and handed to
    /// `failIfThrows`; they are not surfaced to the caller.
    func fetchPreKey(in namespace: PreKeyRecord.Namespace, for keyId: UInt32, tx: DBReadTransaction) -> PreKeyRecord? {
        failIfThrows {
            do {
                return try baseQuery(in: namespace)
                    .filter(Column(PreKeyRecord.CodingKeys.keyId.rawValue) == keyId)
                    .fetchOne(tx.database)
            } catch {
                throw error.grdbErrorForLogging
            }
        }
    }

    /// Returns the serialized record for `keyId` in `namespace`.
    ///
    /// - Throws: `PreKeyStore.Error.noPreKeyWithId` if there is no such key,
    ///   or if the key has no serialized record.
    private func fetchSerializedRecord(in namespace: PreKeyRecord.Namespace, for keyId: UInt32, tx: DBReadTransaction) throws -> Data {
        let preKey = fetchPreKey(in: namespace, for: keyId, tx: tx)
        guard let serializedRecord = preKey?.serializedRecord else {
            throw PreKeyStore.Error.noPreKeyWithId(keyId)
        }
        return serializedRecord
    }

    /// Inserts a key for this identity, replacing any row it conflicts with.
    ///
    /// Only namespace, identity, keyId, isOneTime and serializedRecord are
    /// written. With SQLite's REPLACE conflict resolution the conflicting row
    /// is deleted before the insert, so its other columns (e.g. `replacedAt`)
    /// are not carried over. Which uniqueness constraint triggers the
    /// conflict is defined by the table schema, which is not visible here.
    func upsertPreKeyRecord(
        _ serializedRecord: Data,
        keyId: UInt32,
        in namespace: PreKeyRecord.Namespace,
        isOneTime: Bool,
        tx: DBWriteTransaction,
    ) {
        failIfThrows {
            do {
                // Key IDs intentionally aren't large enough to avoid conflicts when
                // sampling randomly. Clients don't generate conflicting keys, though
                // certain operations (e.g., change number) may produce harmless conflicts.
                // We use "OR REPLACE" to keep the latest key if such a conflict occurs.
                try tx.database.execute(
                    sql: """
                    INSERT OR REPLACE INTO \(PreKeyRecord.databaseTableName) (
                        \(PreKeyRecord.CodingKeys.namespace.rawValue),
                        \(PreKeyRecord.CodingKeys.identity.rawValue),
                        \(PreKeyRecord.CodingKeys.keyId.rawValue),
                        \(PreKeyRecord.CodingKeys.isOneTime.rawValue),
                        \(PreKeyRecord.CodingKeys.serializedRecord.rawValue)
                    ) VALUES (?, ?, ?, ?, ?)
                    """,
                    arguments: [namespace.rawValue, self.identity.rawValue, keyId, isOneTime, serializedRecord],
                )
            } catch {
                // NOTE(review): presumably strips statement details (the
                // arguments here include key material) before logging; confirm.
                throw error.grdbErrorForLogging
            }
        }
    }

    /// Deletes the key with `keyId` in `namespace` for this identity.
    /// Deleting a key that does not exist is a no-op.
    func removePreKey(in namespace: PreKeyRecord.Namespace, keyId: UInt32, tx: DBWriteTransaction) {
        let keyIdColumn = Column(PreKeyRecord.CodingKeys.keyId.rawValue)
        failIfThrows {
            do {
                _ = try baseQuery(in: namespace).filter(keyIdColumn == keyId).deleteAll(tx.database)
            } catch {
                // Map errors the same way as the fetch and upsert paths.
                throw error.grdbErrorForLogging
            }
        }
    }

    #if TESTABLE_BUILD

    /// Counts this identity's keys in `namespace` with the given `isOneTime` flag.
    func fetchCount(in namespace: PreKeyRecord.Namespace, isOneTime: Bool, tx: DBReadTransaction) throws -> Int {
        return try baseQuery(in: namespace)
            .filter(Column(PreKeyRecord.CodingKeys.isOneTime.rawValue) == isOneTime)
            .fetchCount(tx.database)
    }

    #endif
}
|
|
|
|
/// LibSignal one-time pre-key storage, backed by the `.oneTime` namespace.
extension PreKeyStoreForIdentity: LibSignalClient.PreKeyStore {
    /// Loads and deserializes the `.oneTime` key with `id`.
    ///
    /// - Throws: `PreKeyStore.Error.noPreKeyWithId` when the key is missing,
    ///   or any error raised while deserializing the stored bytes.
    func loadPreKey(id: UInt32, context: any StoreContext) throws -> LibSignalClient.PreKeyRecord {
        let storedBytes = try fetchSerializedRecord(in: .oneTime, for: id, tx: context.asTransaction)
        return try LibSignalClient.PreKeyRecord(bytes: storedBytes)
    }

    /// Deletes the `.oneTime` key with `id`; a missing key is not an error.
    func removePreKey(id: UInt32, context: any StoreContext) throws {
        let tx = context.asTransaction
        removePreKey(in: .oneTime, keyId: id, tx: tx)
    }

    /// Not implemented: nothing calls this today, and a real implementation
    /// would also have to track `replacedAt`.
    func storePreKey(_ record: LibSignalClient.PreKeyRecord, id: UInt32, context: any StoreContext) throws {
        owsFail("Not supported.")
    }
}
|
|
|
|
/// LibSignal signed pre-key storage, backed by the `.signed` namespace.
extension PreKeyStoreForIdentity: LibSignalClient.SignedPreKeyStore {
    /// Loads and deserializes the `.signed` key with `id`.
    ///
    /// - Throws: `PreKeyStore.Error.noPreKeyWithId` when the key is missing,
    ///   or any error raised while deserializing the stored bytes.
    func loadSignedPreKey(id: UInt32, context: any StoreContext) throws -> LibSignalClient.SignedPreKeyRecord {
        let storedBytes = try fetchSerializedRecord(in: .signed, for: id, tx: context.asTransaction)
        return try LibSignalClient.SignedPreKeyRecord(bytes: storedBytes)
    }

    /// Not implemented: nothing calls this today, and a real implementation
    /// would also have to track `replacedAt`.
    func storeSignedPreKey(_ record: LibSignalClient.SignedPreKeyRecord, id: UInt32, context: any StoreContext) throws {
        owsFail("Not supported.")
    }
}
|
|
|
|
/// LibSignal Kyber pre-key storage, backed by the `.kyber` namespace.
extension PreKeyStoreForIdentity: LibSignalClient.KyberPreKeyStore {
    /// Loads and deserializes the `.kyber` key with `id`.
    ///
    /// - Throws: `PreKeyStore.Error.noPreKeyWithId` when the key is missing,
    ///   or any error raised while deserializing the stored bytes.
    func loadKyberPreKey(id: UInt32, context: any StoreContext) throws -> LibSignalClient.KyberPreKeyRecord {
        let storedBytes = try fetchSerializedRecord(in: .kyber, for: id, tx: context.asTransaction)
        return try LibSignalClient.KyberPreKeyRecord(bytes: storedBytes)
    }

    /// Records a use of the Kyber key `keyId`.
    ///
    /// A one-time key is deleted. Any other key is kept, and a
    /// `KyberPreKeyUseRecord` (key row, identity, `signedPreKeyId`, serialized
    /// `baseKey`) is inserted for it.
    ///
    /// - Throws: `PreKeyStore.Error.noPreKeyWithId` if the key doesn't exist,
    ///   or the (logging-mapped) database error when the insert violates a
    ///   constraint. Any other database error goes to `failIfThrows`.
    func markKyberPreKeyUsed(id keyId: UInt32, signedPreKeyId: UInt32, baseKey: PublicKey, context: any StoreContext) throws {
        let tx = context.asTransaction
        guard let kyberKey = fetchPreKey(in: .kyber, for: keyId, tx: tx) else {
            throw PreKeyStore.Error.noPreKeyWithId(keyId)
        }

        // One-time keys are simply consumed.
        guard !kyberKey.isOneTime else {
            removePreKey(in: .kyber, keyId: keyId, tx: tx)
            return
        }

        do {
            try KyberPreKeyUseRecord(
                kyberRowId: kyberKey.rowId,
                signedPreKeyIdentity: self.identity,
                signedPreKeyId: signedPreKeyId,
                baseKey: baseKey.serialize(),
            ).insert(tx.database)
        } catch {
            let loggableError = error.grdbErrorForLogging
            switch loggableError {
            case DatabaseError.SQLITE_CONSTRAINT:
                // Presumably a uniqueness constraint on the use record, i.e. this
                // combination was already recorded; surfaced to the caller.
                // NOTE(review): confirm against the table schema.
                throw loggableError
            default:
                failIfThrows { throw loggableError }
            }
        }
    }

    /// Not implemented: nothing calls this today, and it can't be implemented
    /// properly through this interface.
    func storeKyberPreKey(_ record: LibSignalClient.KyberPreKeyRecord, id: UInt32, context: any StoreContext) throws {
        owsFail("Not supported.")
    }
}
|
|
|
|
#if TESTABLE_BUILD
|
|
|
|
/// Test-only write access for one-time pre-keys, with an optional `replacedAt`.
protocol WritablePreKeyStore {
    func storePreKey(_ record: LibSignalClient.PreKeyRecord, replacedAt: Date?, context: any StoreContext) throws
}

extension WritablePreKeyStore where Self: LibSignalClient.PreKeyStore {
    /// Fallback for plain LibSignal stores: forwards to `storePreKey(_:id:context:)`.
    ///
    /// NOTE(review): `replacedAt` is ignored here without the `replacedAt == nil`
    /// precondition that the signed-pre-key fallback enforces. Confirm intended.
    func storePreKey(_ record: LibSignalClient.PreKeyRecord, replacedAt: Date?, context: any StoreContext) throws {
        let keyId = record.id
        try storePreKey(record, id: keyId, context: context)
    }
}

extension PreKeyStoreForIdentity: WritablePreKeyStore {
    /// Upserts `record` as a one-time key in the `.oneTime` namespace.
    /// `replacedAt` must be `nil`; this store can't persist it on this path.
    func storePreKey(_ record: LibSignalClient.PreKeyRecord, replacedAt: Date?, context: any StoreContext) throws {
        owsPrecondition(replacedAt == nil)
        let serialized = record.serialize()
        let tx = context.asTransaction
        upsertPreKeyRecord(serialized, keyId: record.id, in: .oneTime, isOneTime: true, tx: tx)
    }
}
|
|
|
|
/// Test-only write access for signed pre-keys, with an optional `replacedAt`.
protocol WritableSignedPreKeyStore {
    func storeSignedPreKey(_ record: LibSignalClient.SignedPreKeyRecord, replacedAt: Date?, context: any StoreContext) throws
}

extension WritableSignedPreKeyStore where Self: LibSignalClient.SignedPreKeyStore {
    /// Fallback for plain LibSignal stores: requires `replacedAt == nil`, then
    /// forwards to `storeSignedPreKey(_:id:context:)`.
    func storeSignedPreKey(_ record: LibSignalClient.SignedPreKeyRecord, replacedAt: Date?, context: any StoreContext) throws {
        owsPrecondition(replacedAt == nil)
        let keyId = record.id
        try storeSignedPreKey(record, id: keyId, context: context)
    }
}

extension PreKeyStoreForIdentity: WritableSignedPreKeyStore {
    /// Upserts `record` as a non-one-time key in the `.signed` namespace.
    ///
    /// NOTE(review): unlike the one-time and Kyber variants, this does not
    /// assert `replacedAt == nil`, and a non-nil value is silently ignored.
    /// Confirm intended.
    func storeSignedPreKey(_ record: LibSignalClient.SignedPreKeyRecord, replacedAt: Date?, context: any StoreContext) throws {
        let serialized = record.serialize()
        let tx = context.asTransaction
        upsertPreKeyRecord(serialized, keyId: record.id, in: .signed, isOneTime: false, tx: tx)
    }
}
|
|
|
|
/// Test-only write access for Kyber pre-keys, with an explicit one-time flag
/// and an optional `replacedAt`.
protocol WritableKyberPreKeyStore {
    func storeKyberPreKey(_ record: LibSignalClient.KyberPreKeyRecord, isOneTime: Bool, replacedAt: Date?, context: any StoreContext) throws
}

extension WritableKyberPreKeyStore where Self: LibSignalClient.KyberPreKeyStore {
    /// Fallback for plain LibSignal stores: forwards to `storeKyberPreKey(_:id:context:)`.
    ///
    /// NOTE(review): both `isOneTime` and `replacedAt` are ignored here, with no
    /// precondition, unlike the signed-pre-key fallback. Confirm intended.
    func storeKyberPreKey(_ record: LibSignalClient.KyberPreKeyRecord, isOneTime: Bool, replacedAt: Date?, context: any StoreContext) throws {
        let keyId = record.id
        try storeKyberPreKey(record, id: keyId, context: context)
    }
}

extension PreKeyStoreForIdentity: WritableKyberPreKeyStore {
    /// Upserts `record` in the `.kyber` namespace with the given `isOneTime` flag.
    /// `replacedAt` must be `nil`; this store can't persist it on this path.
    func storeKyberPreKey(_ record: LibSignalClient.KyberPreKeyRecord, isOneTime: Bool, replacedAt: Date?, context: any StoreContext) throws {
        owsPrecondition(replacedAt == nil)
        let serialized = record.serialize()
        let tx = context.asTransaction
        upsertPreKeyRecord(serialized, keyId: record.id, in: .kyber, isOneTime: isOneTime, tx: tx)
    }
}
|
|
|
|
#endif
|