Add writeWithRollbackIfThrows method to DB
This commit is contained in:
parent
9a48618897
commit
d65ae767ba
@ -108,6 +108,16 @@ public final class InMemoryDB: DB {
|
||||
return try write(file: file, function: function, line: line, block: block)
|
||||
}
|
||||
|
||||
public func awaitableWriteWithRollbackIfThrows<T, E>(
|
||||
file: String,
|
||||
function: String,
|
||||
line: Int,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) async throws(E) -> T {
|
||||
await Task.yield()
|
||||
return try writeWithRollbackIfThrows(file: file, function: function, line: line, block: block)
|
||||
}
|
||||
|
||||
public func awaitableWriteWithTxCompletion<T>(
|
||||
file: String,
|
||||
function: String,
|
||||
@ -151,7 +161,24 @@ public final class InMemoryDB: DB {
|
||||
line: Int,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) throws(E) -> T {
|
||||
return try _writeCommitIfThrows(block: block, rescue: { err throws(E) in throw err })
|
||||
return try _writeWithTxCompletionIfThrows(
|
||||
block: block,
|
||||
completionIfThrows: .commit(()),
|
||||
rescue: { err throws(E) in throw err }
|
||||
)
|
||||
}
|
||||
|
||||
public func writeWithRollbackIfThrows<T, E>(
|
||||
file: String,
|
||||
function: String,
|
||||
line: Int,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) throws(E) -> T {
|
||||
return try _writeWithTxCompletionIfThrows(
|
||||
block: block,
|
||||
completionIfThrows: .rollback(()),
|
||||
rescue: { err throws(E) in throw err }
|
||||
)
|
||||
}
|
||||
|
||||
public func writeWithTxCompletion<T>(
|
||||
@ -163,20 +190,21 @@ public final class InMemoryDB: DB {
|
||||
return _writeWithTxCompletion(block: block)
|
||||
}
|
||||
|
||||
private func _writeCommitIfThrows<T, E>(
|
||||
private func _writeWithTxCompletionIfThrows<T, E>(
|
||||
block: (DBWriteTransaction) throws(E) -> T,
|
||||
rescue: (E) throws(E) -> Never
|
||||
completionIfThrows: TransactionCompletion<Void>,
|
||||
rescue: (E) throws(E) -> Never,
|
||||
) throws(E) -> T {
|
||||
var result: T!
|
||||
var thrown: E?
|
||||
_writeWithTxCompletion { tx in
|
||||
do throws(E) {
|
||||
result = try block(tx)
|
||||
return .commit(())
|
||||
} catch {
|
||||
thrown = error
|
||||
return completionIfThrows
|
||||
}
|
||||
// Always commit, regardless of thrown errors.
|
||||
return .commit(())
|
||||
}
|
||||
if let thrown {
|
||||
try rescue(thrown)
|
||||
|
||||
@ -345,11 +345,28 @@ public class SDSDatabaseStorage: NSObject, DB {
|
||||
line: Int,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) throws(E) -> T {
|
||||
return try _writeCommitIfThrows(
|
||||
return try _writeWithTxCompletionIfThrows(
|
||||
file: file,
|
||||
function: function,
|
||||
line: line,
|
||||
isAwaitableWrite: false,
|
||||
completionIfThrows: .commit(()),
|
||||
block: block,
|
||||
)
|
||||
}
|
||||
|
||||
public func writeWithRollbackIfThrows<T, E>(
|
||||
file: String,
|
||||
function: String,
|
||||
line: Int,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) throws(E) -> T {
|
||||
return try _writeWithTxCompletionIfThrows(
|
||||
file: file,
|
||||
function: function,
|
||||
line: line,
|
||||
isAwaitableWrite: false,
|
||||
completionIfThrows: .rollback(()),
|
||||
block: block,
|
||||
)
|
||||
}
|
||||
@ -373,11 +390,12 @@ public class SDSDatabaseStorage: NSObject, DB {
|
||||
}
|
||||
}
|
||||
|
||||
private func _writeCommitIfThrows<T, E>(
|
||||
private func _writeWithTxCompletionIfThrows<T, E>(
|
||||
file: String,
|
||||
function: String,
|
||||
line: Int,
|
||||
isAwaitableWrite: Bool,
|
||||
completionIfThrows: TransactionCompletion<Void>,
|
||||
block: (DBWriteTransaction) throws(E) -> T,
|
||||
) throws(E) -> T {
|
||||
var value: T!
|
||||
@ -391,11 +409,11 @@ public class SDSDatabaseStorage: NSObject, DB {
|
||||
) { tx in
|
||||
do throws(E) {
|
||||
value = try block(tx)
|
||||
return .commit(())
|
||||
} catch {
|
||||
thrown = error
|
||||
return completionIfThrows
|
||||
}
|
||||
// Always commit regardless of thrown errors.
|
||||
return .commit(())
|
||||
}
|
||||
} catch {
|
||||
owsFail("error: \(error.grdbErrorForLogging)")
|
||||
@ -495,7 +513,32 @@ public class SDSDatabaseStorage: NSObject, DB {
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) async throws(E) -> T {
|
||||
return try await self.awaitableWriteQueue.runWithoutTaskCancellationHandler { () throws(E) -> T in
|
||||
return try self._writeCommitIfThrows(file: file, function: function, line: line, isAwaitableWrite: true, block: block)
|
||||
return try self._writeWithTxCompletionIfThrows(
|
||||
file: file,
|
||||
function: function,
|
||||
line: line,
|
||||
isAwaitableWrite: true,
|
||||
completionIfThrows: .commit(()),
|
||||
block: block
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
public func awaitableWriteWithRollbackIfThrows<T, E>(
|
||||
file: String,
|
||||
function: String,
|
||||
line: Int,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) async throws(E) -> T {
|
||||
return try await self.awaitableWriteQueue.runWithoutTaskCancellationHandler { () throws(E) -> T in
|
||||
return try self._writeWithTxCompletionIfThrows(
|
||||
file: file,
|
||||
function: function,
|
||||
line: line,
|
||||
isAwaitableWrite: true,
|
||||
completionIfThrows: .rollback(()),
|
||||
block: block
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -45,6 +45,13 @@ public protocol DB {
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) async throws(E) -> T
|
||||
|
||||
func awaitableWriteWithRollbackIfThrows<T, E>(
|
||||
file: String,
|
||||
function: String,
|
||||
line: Int,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) async throws(E) -> T
|
||||
|
||||
func awaitableWriteWithTxCompletion<T>(
|
||||
file: String,
|
||||
function: String,
|
||||
@ -68,6 +75,13 @@ public protocol DB {
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) throws(E) -> T
|
||||
|
||||
func writeWithRollbackIfThrows<T, E>(
|
||||
file: String,
|
||||
function: String,
|
||||
line: Int,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) throws(E) -> T
|
||||
|
||||
func writeWithTxCompletion<T>(
|
||||
file: String,
|
||||
function: String,
|
||||
@ -142,6 +156,15 @@ extension DB {
|
||||
return try await awaitableWrite(file: file, function: function, line: line, block: block)
|
||||
}
|
||||
|
||||
public func awaitableWriteWithRollbackIfThrows<T, E>(
|
||||
file: String = #file,
|
||||
function: String = #function,
|
||||
line: Int = #line,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) async throws(E) -> T {
|
||||
return try await awaitableWriteWithRollbackIfThrows(file: file, function: function, line: line, block: block)
|
||||
}
|
||||
|
||||
public func awaitableWriteWithTxCompletion<T>(
|
||||
file: String = #file,
|
||||
function: String = #function,
|
||||
@ -171,6 +194,15 @@ extension DB {
|
||||
return try write(file: file, function: function, line: line, block: block)
|
||||
}
|
||||
|
||||
public func writeWithRollbackIfThrows<T, E>(
|
||||
file: String = #file,
|
||||
function: String = #function,
|
||||
line: Int = #line,
|
||||
block: (DBWriteTransaction) throws(E) -> T
|
||||
) throws(E) -> T {
|
||||
return try writeWithRollbackIfThrows(file: file, function: function, line: line, block: block)
|
||||
}
|
||||
|
||||
public func writeWithTxCompletion<T>(
|
||||
file: String = #file,
|
||||
function: String = #function,
|
||||
|
||||
@ -54,6 +54,62 @@ class SDSDatabaseStorageRollbackTest: SSKBaseTest {
|
||||
}
|
||||
}
|
||||
|
||||
func test_writeWithRollbackIfThrows() {
|
||||
try? databaseStorage.writeWithRollbackIfThrows { tx in
|
||||
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
|
||||
kvStore.setBool(true, key: key, transaction: tx)
|
||||
throw SomeError()
|
||||
}
|
||||
|
||||
// Should have rolled back.
|
||||
databaseStorage.read { tx in
|
||||
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
|
||||
}
|
||||
|
||||
// Run it again but catch the throw this time.
|
||||
databaseStorage.writeWithRollbackIfThrows { tx in
|
||||
do {
|
||||
kvStore.setBool(true, key: key, transaction: tx)
|
||||
throw SomeError()
|
||||
} catch {
|
||||
// Suppress error
|
||||
}
|
||||
}
|
||||
|
||||
// Should NOT have rolled back.
|
||||
databaseStorage.read { tx in
|
||||
XCTAssertTrue(kvStore.getBool(key, defaultValue: false, transaction: tx))
|
||||
}
|
||||
}
|
||||
|
||||
func test_writeWithRollbackIfThrows_async() async {
|
||||
try? await databaseStorage.awaitableWriteWithRollbackIfThrows { tx in
|
||||
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
|
||||
kvStore.setBool(true, key: key, transaction: tx)
|
||||
throw SomeError()
|
||||
}
|
||||
|
||||
// Should have rolled back.
|
||||
databaseStorage.read { tx in
|
||||
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
|
||||
}
|
||||
|
||||
// Run it again but catch the throw this time.
|
||||
await databaseStorage.awaitableWriteWithRollbackIfThrows { tx in
|
||||
do {
|
||||
kvStore.setBool(true, key: key, transaction: tx)
|
||||
throw SomeError()
|
||||
} catch {
|
||||
// Suppress error
|
||||
}
|
||||
}
|
||||
|
||||
// Should NOT have rolled back.
|
||||
databaseStorage.read { tx in
|
||||
XCTAssertTrue(kvStore.getBool(key, defaultValue: false, transaction: tx))
|
||||
}
|
||||
}
|
||||
|
||||
func test_writeWithTxCompletionRollback() {
|
||||
databaseStorage.writeWithTxCompletion { tx in
|
||||
XCTAssertFalse(kvStore.getBool(key, defaultValue: false, transaction: tx))
|
||||
@ -165,6 +221,62 @@ class InMemoryDBRollbackTest: XCTestCase {
|
||||
}
|
||||
}
|
||||
|
||||
func test_writeWithRollbackIfThrows() {
|
||||
try? db.writeWithRollbackIfThrows { tx in
|
||||
XCTAssertFalse(getBool(tx: tx))
|
||||
setBool(true, tx: tx)
|
||||
throw SomeError()
|
||||
}
|
||||
|
||||
// Should have rolled back.
|
||||
db.read { tx in
|
||||
XCTAssertFalse(getBool(tx: tx))
|
||||
}
|
||||
|
||||
// Run it again but catch the throw this time.
|
||||
db.writeWithRollbackIfThrows { tx in
|
||||
do {
|
||||
setBool(true, tx: tx)
|
||||
throw SomeError()
|
||||
} catch {
|
||||
// Suppress error
|
||||
}
|
||||
}
|
||||
|
||||
// Should NOT have rolled back.
|
||||
db.read { tx in
|
||||
XCTAssertTrue(getBool(tx: tx))
|
||||
}
|
||||
}
|
||||
|
||||
func test_writeWithRollbackIfThrows_async() async {
|
||||
try? await db.awaitableWriteWithRollbackIfThrows { tx in
|
||||
XCTAssertFalse(getBool(tx: tx))
|
||||
setBool(true, tx: tx)
|
||||
throw SomeError()
|
||||
}
|
||||
|
||||
// Should have rolled back.
|
||||
db.read { tx in
|
||||
XCTAssertFalse(getBool(tx: tx))
|
||||
}
|
||||
|
||||
// Run it again but catch the throw this time.
|
||||
await db.awaitableWriteWithRollbackIfThrows { tx in
|
||||
do {
|
||||
setBool(true, tx: tx)
|
||||
throw SomeError()
|
||||
} catch {
|
||||
// Suppress error
|
||||
}
|
||||
}
|
||||
|
||||
// Should NOT have rolled back.
|
||||
db.read { tx in
|
||||
XCTAssertTrue(getBool(tx: tx))
|
||||
}
|
||||
}
|
||||
|
||||
func test_writeWithTxCompletionRollback() {
|
||||
db.writeWithTxCompletion { tx in
|
||||
XCTAssertFalse(getBool(tx: tx))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user