302 lines
12 KiB
Swift
302 lines
12 KiB
Swift
//
// Copyright 2021 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
|
|
|
|
import CryptoKit
|
|
import Foundation
|
|
public import GRDB
|
|
|
|
/// Per-thread state stored alongside a `TSThread`: archived / marked-unread flags, mute expiry,
/// audio playback rate, and the hash of the last group name set by the local user.
///
/// Rows live in the `thread_associated_data` table and are looked up by `threadUniqueId`.
public class ThreadAssociatedData: NSObject, Codable, FetchableRecord, PersistableRecord {
    public static let databaseTableName = "thread_associated_data"

    /// Database row id; `nil` until the record has been inserted (set in `didInsert`).
    public private(set) var id: Int64?

    /// `uniqueId` of the thread this data belongs to.
    public let threadUniqueId: String

    public private(set) var isArchived: Bool = false
    public private(set) var isMarkedUnread: Bool = false

    /// Milliseconds since 1970 until which the thread is muted. `0` means no mute is set.
    public private(set) var mutedUntilTimestamp: UInt64 = 0

    public private(set) var audioPlaybackRate: Float = 1

    /// SHA-256 hash (see `groupNameVerificationHash`) of the last group name that was set by
    /// the local user. `nil` if it has never been set by the local user.
    public private(set) var lastVerifiedGroupNameHash: Data?

    /// Whether the mute expiry lies in the future.
    public var isMuted: Bool { mutedUntilTimestamp > Date.ows_millisecondTimestamp() }

    /// The mute expiry as a `Date`, or `nil` when `mutedUntilTimestamp` is `0`.
    public var mutedUntilDate: Date? {
        guard mutedUntilTimestamp > 0 else { return nil }
        return Date(millisecondsSince1970: mutedUntilTimestamp)
    }

    /// Returns the SHA-256 digest of `groupName`'s UTF-8 bytes, or `nil` if `groupName` is `nil`
    /// or cannot be encoded as UTF-8.
    static func groupNameVerificationHash(groupName: String?) -> Data? {
        guard let groupName, let groupNameData = groupName.data(using: .utf8) else {
            return nil
        }
        return Data(SHA256.hash(data: groupNameData))
    }

    /// Whether `groupName` matches the last group name verified by the local user.
    ///
    /// Returns `false` when no name has ever been verified. Comparing the two optionals directly
    /// would treat a missing stored hash and an un-hashable name (`nil == nil`) as a match.
    public func isGroupNameVerified(groupName: String) -> Bool {
        guard let lastVerifiedGroupNameHash else {
            return false
        }
        return lastVerifiedGroupNameHash == Self.groupNameVerificationHash(groupName: groupName)
    }

    /// The largest mute expiry that can be stored. `mutedUntilTimestamp` is persisted as an
    /// `Int64` bit pattern (see `encode(to:)`), so `Int64.max` is the largest value that
    /// round-trips as a positive number.
    public static var alwaysMutedTimestamp: UInt64 { UInt64(Int64.max) }

    /// Fetches the associated data for `thread` (see `fetchOrDefault(for:ignoreMissing:transaction:)`).
    public static func fetchOrDefault(
        for thread: TSThread,
        transaction: DBReadTransaction,
    ) -> ThreadAssociatedData {
        fetchOrDefault(for: thread.uniqueId, ignoreMissing: false, transaction: transaction)
    }

    /// Fetches the associated data for `threadUniqueId` (see `fetchOrDefault(for:ignoreMissing:transaction:)`).
    public static func fetchOrDefault(
        for threadUniqueId: String,
        transaction: DBReadTransaction,
    ) -> ThreadAssociatedData {
        fetchOrDefault(for: threadUniqueId, ignoreMissing: false, transaction: transaction)
    }

    /// Fetches the associated data for `thread` (see `fetchOrDefault(for:ignoreMissing:transaction:)`).
    public static func fetchOrDefault(
        for thread: TSThread,
        ignoreMissing: Bool,
        transaction: DBReadTransaction,
    ) -> ThreadAssociatedData {
        fetchOrDefault(for: thread.uniqueId, ignoreMissing: ignoreMissing, transaction: transaction)
    }

    /// Fetches the associated data for `threadUniqueId` through `threadAssociatedDataStore`.
    ///
    /// `ignoreMissing` is forced to `true` while running tests and for the mock thread ids
    /// `"MockThread"` / `"MockGroupThread"`. Presumably it suppresses the store's
    /// missing-record failure; confirm against `ThreadAssociatedDataStore.fetchOrDefault`.
    public static func fetchOrDefault(
        for threadUniqueId: String,
        ignoreMissing: Bool,
        transaction: DBReadTransaction,
    ) -> ThreadAssociatedData {
        DependenciesBridge.shared.threadAssociatedDataStore.fetchOrDefault(
            for: threadUniqueId,
            ignoreMissing: ignoreMissing || CurrentAppContext().isRunningTests || threadUniqueId == "MockThread" || threadUniqueId == "MockGroupThread",
            tx: transaction,
        )
    }

    /// Reads the row for `threadUniqueId` directly from the database.
    ///
    /// Returns `nil` if no row exists or the read throws. A read failure is reported via `owsFailDebug`.
    public static func fetch(
        for threadUniqueId: String,
        transaction: DBReadTransaction,
    ) -> ThreadAssociatedData? {
        do {
            // Use the coding key so the column name cannot drift from the encoded key.
            return try Self.filter(Column(CodingKeys.threadUniqueId) == threadUniqueId).fetchOne(transaction.database)
        } catch {
            owsFailDebug("Failed to read associated data \(error)")
            return nil
        }
    }

    /// Inserts a row with default values for `threadUniqueId`, unless the store already has one.
    public static func create(for threadUniqueId: String, transaction: DBWriteTransaction) {
        let threadAssociatedDataStore = DependenciesBridge.shared.threadAssociatedDataStore
        guard threadAssociatedDataStore.fetch(for: threadUniqueId, tx: transaction) == nil else {
            return
        }
        do {
            try ThreadAssociatedData(threadUniqueId: threadUniqueId).insert(transaction.database)
        } catch {
            owsFailDebug("Unexpectedly failed to insert \(error)")
        }
    }

    /// Creates an unsaved record for `threadUniqueId` with every other property at its default.
    init(threadUniqueId: String) {
        self.threadUniqueId = threadUniqueId
        super.init()
    }

    public enum CodingKeys: String, CodingKey {
        case id
        case threadUniqueId
        case isArchived
        case isMarkedUnread
        case mutedUntilTimestamp
        case audioPlaybackRate
        case lastVerifiedGroupNameHash
    }

    /// Decodes a row. Any missing optional column leaves the property at its declared default.
    public required init(from decoder: any Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        self.id = try container.decodeIfPresent(Int64.self, forKey: .id)
        self.threadUniqueId = try container.decode(String.self, forKey: .threadUniqueId)
        if let isArchived = try container.decodeIfPresent(Bool.self, forKey: .isArchived) {
            self.isArchived = isArchived
        }
        if let isMarkedUnread = try container.decodeIfPresent(Bool.self, forKey: .isMarkedUnread) {
            self.isMarkedUnread = isMarkedUnread
        }
        // Stored as a signed 64-bit integer; reinterpret the bits as unsigned.
        if let mutedUntilTimestamp = try container.decodeIfPresent(Int64.self, forKey: .mutedUntilTimestamp) {
            self.mutedUntilTimestamp = UInt64(bitPattern: mutedUntilTimestamp)
        }
        if let audioPlaybackRate = try container.decodeIfPresent(Float.self, forKey: .audioPlaybackRate) {
            self.audioPlaybackRate = audioPlaybackRate
        }
        if let lastVerifiedGroupNameHash = try container.decodeIfPresent(Data.self, forKey: .lastVerifiedGroupNameHash) {
            self.lastVerifiedGroupNameHash = lastVerifiedGroupNameHash
        }
    }

    /// Encodes every column. `mutedUntilTimestamp` is written as an `Int64` bit pattern, and
    /// `lastVerifiedGroupNameHash` is omitted entirely when it is `nil`.
    public func encode(to encoder: any Encoder) throws {
        var container = encoder.container(keyedBy: CodingKeys.self)
        try container.encode(self.id, forKey: .id)
        try container.encode(self.threadUniqueId, forKey: .threadUniqueId)
        try container.encode(self.isArchived, forKey: .isArchived)
        try container.encode(self.isMarkedUnread, forKey: .isMarkedUnread)
        try container.encode(Int64(bitPattern: self.mutedUntilTimestamp), forKey: .mutedUntilTimestamp)
        try container.encode(self.audioPlaybackRate, forKey: .audioPlaybackRate)
        try container.encodeIfPresent(lastVerifiedGroupNameHash, forKey: .lastVerifiedGroupNameHash)
    }

    /// Records the row id assigned by the database after an insert.
    public func didInsert(with rowID: Int64, for column: String?) {
        self.id = rowID
    }

    /// Creates an unsaved record with every property specified explicitly.
    public init(
        threadUniqueId: String,
        isArchived: Bool,
        isMarkedUnread: Bool,
        mutedUntilTimestamp: UInt64,
        audioPlaybackRate: Float,
        lastVerifiedGroupNameHash: Data?,
    ) {
        self.threadUniqueId = threadUniqueId
        self.isArchived = isArchived
        self.isMarkedUnread = isMarkedUnread
        self.mutedUntilTimestamp = mutedUntilTimestamp
        self.audioPlaybackRate = audioPlaybackRate
        self.lastVerifiedGroupNameHash = lastVerifiedGroupNameHash
        super.init()
    }

    /// Sets each non-`nil` argument on this instance and on the stored row, then persists it.
    ///
    /// At least one value must be non-`nil`; otherwise this fails via `owsFailDebug` and does nothing.
    /// Because `nil` means "leave unchanged", this cannot clear `lastVerifiedGroupNameHash`.
    public func updateWith(
        isArchived: Bool? = nil,
        isMarkedUnread: Bool? = nil,
        mutedUntilTimestamp: UInt64? = nil,
        audioPlaybackRate: Float? = nil,
        lastVerifiedGroupNameHash: Data? = nil,
        updateStorageService: Bool,
        transaction: DBWriteTransaction,
    ) {
        guard
            isArchived != nil
            || isMarkedUnread != nil
            || mutedUntilTimestamp != nil
            || audioPlaybackRate != nil
            || lastVerifiedGroupNameHash != nil
        else {
            return owsFailDebug("You must set one value")
        }

        updateWith(updateStorageService: updateStorageService, transaction: transaction) { associatedData in
            if let isArchived {
                associatedData.isArchived = isArchived
            }
            if let isMarkedUnread {
                associatedData.isMarkedUnread = isMarkedUnread
            }
            if let mutedUntilTimestamp {
                associatedData.mutedUntilTimestamp = mutedUntilTimestamp
            }
            if let audioPlaybackRate {
                associatedData.audioPlaybackRate = audioPlaybackRate
            }
            if let lastVerifiedGroupNameHash {
                associatedData.lastVerifiedGroupNameHash = lastVerifiedGroupNameHash
            }
        }
    }

    /// Resets the selected flags to `false` and persists the change. Does nothing if neither flag is selected.
    public func clear(
        isArchived clearIsArchived: Bool = false,
        isMarkedUnread clearIsMarkedUnread: Bool = false,
        updateStorageService: Bool,
        transaction: DBWriteTransaction,
    ) {
        guard clearIsArchived || clearIsMarkedUnread else { return }
        updateWith(updateStorageService: updateStorageService, transaction: transaction) { associatedData in
            if clearIsArchived { associatedData.isArchived = false }
            if clearIsMarkedUnread { associatedData.isMarkedUnread = false }
        }
    }

    /// Applies `block` to this instance and to the stored copy (if that is a different object),
    /// then persists the result.
    ///
    /// After persisting, it touches the thread so observers see a change, and, when
    /// `updateStorageService` is set, records a pending storage-service update for it.
    private func updateWith(updateStorageService: Bool, transaction: DBWriteTransaction, block: (ThreadAssociatedData) -> Void) {
        // Keep this in-memory instance in sync even though the write goes through the stored copy.
        block(self)

        let threadAssociatedDataStore = DependenciesBridge.shared.threadAssociatedDataStore
        let storedCopy = threadAssociatedDataStore.fetch(for: threadUniqueId, tx: transaction)

        if let storedCopy {
            if storedCopy !== self {
                block(storedCopy)
            }
            do {
                try storedCopy.update(transaction.database)
            } catch {
                owsFailDebug("Unexpectedly failed to update \(error)")
            }
        } else {
            // No row exists yet. Flag it, but still persist this instance rather than lose the change.
            owsFailDebug("Could not update missing record.")
            do {
                try insert(transaction.database)
            } catch {
                owsFailDebug("Unexpectedly failed to insert \(error)")
            }
        }

        // Look the thread up once; it is needed both for the UI touch and the storage service update.
        let thread = TSThread.fetchViaCache(uniqueId: threadUniqueId, transaction: transaction)

        // If the thread model exists, make sure the UI is notified that it has changed.
        if let thread {
            SSKEnvironment.shared.databaseStorageRef.touch(thread: thread, shouldReindex: false, tx: transaction)
        }

        guard updateStorageService else { return }

        guard let thread else {
            return owsFailDebug("Unexpectedly missing thread for storage service update.")
        }

        if let groupThread = thread as? TSGroupThread {
            SSKEnvironment.shared.storageServiceManagerRef.recordPendingUpdates(groupModel: groupThread.groupModel)
        } else if let contactThread = thread as? TSContactThread {
            SSKEnvironment.shared.storageServiceManagerRef.recordPendingUpdates(updatedAddresses: [contactThread.contactAddress])
        } else if thread.isReleaseNotesThread {
            // TODO: sync release notes mute state in storage service
        } else {
            owsFailDebug("Unexpected thread type")
        }
    }
}
|
|
|
|
public extension TSThread {
    /// Marks every unread message in this thread as read, then clears the thread's
    /// `isMarkedUnread` flag.
    ///
    /// - Parameters:
    ///   - updateStorageService: Forwarded to `ThreadAssociatedData.updateWith` for the flag change.
    ///   - transaction: Write transaction that all changes are made in.
    func markAllAsRead(updateStorageService: Bool, transaction: DBWriteTransaction) {
        markAllAsRead(transaction: transaction)

        ThreadAssociatedData
            .fetchOrDefault(for: self, transaction: transaction)
            .updateWith(isMarkedUnread: false, updateStorageService: updateStorageService, transaction: transaction)
    }

    /// Walks this thread's unread messages and marks each one read on this device.
    private func markAllAsRead(transaction: DBWriteTransaction) {
        // The receipt circumstance depends on whether a message request is still pending.
        let readCircumstance: OWSReceiptCircumstance = if hasPendingMessageRequest(transaction: transaction) {
            .onThisDeviceWhilePendingMessageRequest
        } else {
            .onThisDevice
        }

        let interactionFinder = InteractionFinder(threadUniqueId: uniqueId)
        var unreadCursor = interactionFinder.fetchAllUnreadMessages(transaction: transaction)
        do {
            while let unreadMessage = try unreadCursor.next() {
                unreadMessage.markAsRead(
                    atTimestamp: Date.ows_millisecondTimestamp(),
                    thread: self,
                    circumstance: readCircumstance,
                    shouldClearNotifications: true,
                    transaction: transaction,
                )
            }
        } catch {
            owsFailDebug("unexpected failure fetching unread messages: \(error)")
        }

        // Defensive check: nothing in this thread should still count as unread.
        owsAssertDebug(interactionFinder.unreadCount(transaction: transaction) == 0)
    }
}
|