Add withCooperativeRace

This commit is contained in:
Max Radermacher 2025-05-30 02:13:35 -05:00 committed by GitHub
parent 6123108bd7
commit cb29b883f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,11 +19,11 @@ public struct CooperativeTimeoutError: Error {}
/// value from invoking `operation`, the error thrown when invoking
/// `operation`, or the `CooperativeTimeoutError` thrown after `seconds`.
public func withCooperativeTimeout<T>(seconds: TimeInterval, operation: @escaping () async throws -> T) async throws -> T {
return try await withThrowingTaskGroup(of: T?.self) { taskGroup in
taskGroup.addTask {
return try await operation()
}
taskGroup.addTask {
let results = await _withCooperativeRace(operations: [
{ () async throws -> T? in
try await operation()
},
{ () async throws -> T? in
do {
try await Task.sleep(nanoseconds: seconds.clampedNanoseconds)
} catch {
@ -34,21 +34,54 @@ public func withCooperativeTimeout<T>(seconds: TimeInterval, operation: @escapin
return nil
}
throw CooperativeTimeoutError()
}
if let result = try await taskGroup.next()! {
// If the first child Task to finish throws an error, that error will be
// rethrown on the prior line. When `withThrowingTaskGroup` throws an error
// from its body, it cancels all the other child Tasks. If, however, the
// first child Task to finish doesn't throw an error, and if it's the child
// Task that's invoking operation(), we must cancel the other one to avoid
// waiting for it to time out.
taskGroup.cancelAll()
return result
} else {
// If the first result is nil, it means the cooperative timeout child Task
// was canceled. We must to wait for the other task to finish -- the one
// invoking operation() -- and pass through its result to the caller.
return try await taskGroup.next()!!
},
])
for result in results {
if let operationResult = try result.get() {
return operationResult
}
}
// There are always two results. If at least one of them throws an Error,
// that error will be re-thrown above, and we can't reach this code. If
// neither of them throws an Error, the result from invoking operation()
// will be nonnil, it will be returned above, and we can't reach this code.
owsFail("Can't reach this code.")
}
/// Invokes `operation` & `operations`, passing through the earliest result.
///
/// This method doesn't return until `operation` and every element of
/// `operations` returns. In other words, every operation must cooperate
/// with the cancellation request.
///
/// The `operation` and `operations` parameters are separated in the
/// function signature to require callers to provide at least one operation.
/// This provides compile-time safety for the `.first!` forced unwrap.
public func withCooperativeRace<T>(
_ operation: @escaping () async throws -> T,
_ operations: (() async throws -> T)...,
) async throws -> T {
return try await _withCooperativeRace(operations: [operation] + operations).first!.get()
}
private func _withCooperativeRace<T>(operations: [() async throws -> T]) async -> [Result<T, any Error>] {
return await withThrowingTaskGroup { taskGroup in
for operation in operations {
taskGroup.addTask {
return try await operation()
}
}
var results = [Result<T, any Error>]()
if let firstResult = await taskGroup.nextResult() {
results.append(firstResult)
// Cancel everything else as soon as anything wins the race.
taskGroup.cancelAll()
// This is cooperative, so even though we canceled all the other
// operations, they may still produce meaningful results.
while let otherResult = await taskGroup.nextResult() {
results.append(otherResult)
}
}
return results
}
}