Add writeWithRollbackIfThrows method to DB

This commit is contained in:
Sasha Weiss 2025-06-30 10:15:07 -07:00 committed by GitHub
parent 9a48618897
commit d65ae767ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 225 additions and 10 deletions

View File

@ -108,6 +108,16 @@ public final class InMemoryDB: DB {
return try write(file: file, function: function, line: line, block: block)
}
public func awaitableWriteWithRollbackIfThrows<T, E>(
file: String,
function: String,
line: Int,
block: (DBWriteTransaction) throws(E) -> T
) async throws(E) -> T {
await Task.yield()
return try writeWithRollbackIfThrows(file: file, function: function, line: line, block: block)
}
public func awaitableWriteWithTxCompletion<T>(
file: String,
function: String,
@ -151,7 +161,24 @@ public final class InMemoryDB: DB {
line: Int,
block: (DBWriteTransaction) throws(E) -> T
) throws(E) -> T {
return try _writeCommitIfThrows(block: block, rescue: { err throws(E) in throw err })
return try _writeWithTxCompletionIfThrows(
block: block,
completionIfThrows: .commit(()),
rescue: { err throws(E) in throw err }
)
}
public func writeWithRollbackIfThrows<T, E>(
file: String,
function: String,
line: Int,
block: (DBWriteTransaction) throws(E) -> T
) throws(E) -> T {
return try _writeWithTxCompletionIfThrows(
block: block,
completionIfThrows: .rollback(()),
rescue: { err throws(E) in throw err }
)
}
public func writeWithTxCompletion<T>(
@ -163,20 +190,21 @@ public final class InMemoryDB: DB {
return _writeWithTxCompletion(block: block)
}
private func _writeCommitIfThrows<T, E>(
private func _writeWithTxCompletionIfThrows<T, E>(
block: (DBWriteTransaction) throws(E) -> T,
rescue: (E) throws(E) -> Never
completionIfThrows: TransactionCompletion<Void>,
rescue: (E) throws(E) -> Never,
) throws(E) -> T {
var result: T!
var thrown: E?
_writeWithTxCompletion { tx in
do throws(E) {
result = try block(tx)
return .commit(())
} catch {
thrown = error
return completionIfThrows
}
// Always commit, regardless of thrown errors.
return .commit(())
}
if let thrown {
try rescue(thrown)

View File

@ -345,11 +345,28 @@ public class SDSDatabaseStorage: NSObject, DB {
line: Int,
block: (DBWriteTransaction) throws(E) -> T
) throws(E) -> T {
return try _writeCommitIfThrows(
return try _writeWithTxCompletionIfThrows(
file: file,
function: function,
line: line,
isAwaitableWrite: false,
completionIfThrows: .commit(()),
block: block,
)
}
public func writeWithRollbackIfThrows<T, E>(
file: String,
function: String,
line: Int,
block: (DBWriteTransaction) throws(E) -> T
) throws(E) -> T {
return try _writeWithTxCompletionIfThrows(
file: file,
function: function,
line: line,
isAwaitableWrite: false,
completionIfThrows: .rollback(()),
block: block,
)
}
@ -373,11 +390,12 @@ public class SDSDatabaseStorage: NSObject, DB {
}
}
private func _writeCommitIfThrows<T, E>(
private func _writeWithTxCompletionIfThrows<T, E>(
file: String,
function: String,
line: Int,
isAwaitableWrite: Bool,
completionIfThrows: TransactionCompletion<Void>,
block: (DBWriteTransaction) throws(E) -> T,
) throws(E) -> T {
var value: T!
@ -391,11 +409,11 @@ public class SDSDatabaseStorage: NSObject, DB {
) { tx in
do throws(E) {
value = try block(tx)
return .commit(())
} catch {
thrown = error
return completionIfThrows
}
// Always commit regardless of thrown errors.
return .commit(())
}
} catch {
owsFail("error: \(error.grdbErrorForLogging)")
@ -495,7 +513,32 @@ public class SDSDatabaseStorage: NSObject, DB {
block: (DBWriteTransaction) throws(E) -> T
) async throws(E) -> T {
return try await self.awaitableWriteQueue.runWithoutTaskCancellationHandler { () throws(E) -> T in
return try self._writeCommitIfThrows(file: file, function: function, line: line, isAwaitableWrite: true, block: block)
return try self._writeWithTxCompletionIfThrows(
file: file,
function: function,
line: line,
isAwaitableWrite: true,
completionIfThrows: .commit(()),
block: block
)
}
}
public func awaitableWriteWithRollbackIfThrows<T, E>(
file: String,
function: String,
line: Int,
block: (DBWriteTransaction) throws(E) -> T
) async throws(E) -> T {
return try await self.awaitableWriteQueue.runWithoutTaskCancellationHandler { () throws(E) -> T in
return try self._writeWithTxCompletionIfThrows(
file: file,
function: function,
line: line,
isAwaitableWrite: true,
completionIfThrows: .rollback(()),
block: block
)
}
}

View File

@ -45,6 +45,13 @@ public protocol DB {
block: (DBWriteTransaction) throws(E) -> T
) async throws(E) -> T
func awaitableWriteWithRollbackIfThrows<T, E>(
file: String,
function: String,
line: Int,
block: (DBWriteTransaction) throws(E) -> T
) async throws(E) -> T
func awaitableWriteWithTxCompletion<T>(
file: String,
function: String,
@ -68,6 +75,13 @@ public protocol DB {
block: (DBWriteTransaction) throws(E) -> T
) throws(E) -> T
func writeWithRollbackIfThrows<T, E>(
file: String,
function: String,
line: Int,
block: (DBWriteTransaction) throws(E) -> T
) throws(E) -> T
func writeWithTxCompletion<T>(
file: String,
function: String,
@ -142,6 +156,15 @@ extension DB {
return try await awaitableWrite(file: file, function: function, line: line, block: block)
}
public func awaitableWriteWithRollbackIfThrows<T, E>(
file: String = #file,
function: String = #function,
line: Int = #line,
block: (DBWriteTransaction) throws(E) -> T
) async throws(E) -> T {
return try await awaitableWriteWithRollbackIfThrows(file: file, function: function, line: line, block: block)
}
public func awaitableWriteWithTxCompletion<T>(
file: String = #file,
function: String = #function,
@ -171,6 +194,15 @@ extension DB {
return try write(file: file, function: function, line: line, block: block)
}
public func writeWithRollbackIfThrows<T, E>(
file: String = #file,
function: String = #function,
line: Int = #line,
block: (DBWriteTransaction) throws(E) -> T
) throws(E) -> T {
return try writeWithRollbackIfThrows(file: file, function: function, line: line, block: block)
}
public func writeWithTxCompletion<T>(
file: String = #file,
function: String = #function,

View File

@ -54,6 +54,62 @@ class SDSDatabaseStorageRollbackTest: SSKBaseTest {
}
}
func test_writeWithRollbackIfThrows() {
try? databaseStorage.writeWithRollbackIfThrows { tx in
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
kvStore.setBool(true, key: key, transaction: tx)
throw SomeError()
}
// Should have rolled back.
databaseStorage.read { tx in
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
}
// Run it again but catch the throw this time.
databaseStorage.writeWithRollbackIfThrows { tx in
do {
kvStore.setBool(true, key: key, transaction: tx)
throw SomeError()
} catch {
// Suppress error
}
}
// Should NOT have rolled back.
databaseStorage.read { tx in
XCTAssertTrue(kvStore.getBool(key, defaultValue: false, transaction: tx))
}
}
func test_writeWithRollbackIfThrows_async() async {
try? await databaseStorage.awaitableWriteWithRollbackIfThrows { tx in
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
kvStore.setBool(true, key: key, transaction: tx)
throw SomeError()
}
// Should have rolled back.
databaseStorage.read { tx in
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
}
// Run it again but catch the throw this time.
await databaseStorage.awaitableWriteWithRollbackIfThrows { tx in
do {
kvStore.setBool(true, key: key, transaction: tx)
throw SomeError()
} catch {
// Suppress error
}
}
// Should NOT have rolled back.
databaseStorage.read { tx in
XCTAssertTrue(kvStore.getBool(key, defaultValue: false, transaction: tx))
}
}
func test_writeWithTxCompletionRollback() {
databaseStorage.writeWithTxCompletion { tx in
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
@ -165,6 +221,62 @@ class InMemoryDBRollbackTest: XCTestCase {
}
}
func test_writeWithRollbackIfThrows() {
try? db.writeWithRollbackIfThrows { tx in
XCTAssertFalse(getBool(tx: tx))
setBool(true, tx: tx)
throw SomeError()
}
// Should have rolled back.
db.read { tx in
XCTAssertFalse(getBool(tx: tx))
}
// Run it again but catch the throw this time.
db.writeWithRollbackIfThrows { tx in
do {
setBool(true, tx: tx)
throw SomeError()
} catch {
// Suppress error
}
}
// Should NOT have rolled back.
db.read { tx in
XCTAssertTrue(getBool(tx: tx))
}
}
func test_writeWithRollbackIfThrows_async() async {
try? await db.awaitableWriteWithRollbackIfThrows { tx in
XCTAssertFalse(getBool(tx: tx))
setBool(true, tx: tx)
throw SomeError()
}
// Should have rolled back.
db.read { tx in
XCTAssertFalse(getBool(tx: tx))
}
// Run it again but catch the throw this time.
await db.awaitableWriteWithRollbackIfThrows { tx in
do {
setBool(true, tx: tx)
throw SomeError()
} catch {
// Suppress error
}
}
// Should NOT have rolled back.
db.read { tx in
XCTAssertTrue(getBool(tx: tx))
}
}
func test_writeWithTxCompletionRollback() {
db.writeWithTxCompletion { tx in
XCTAssertFalse(getBool(tx: tx))