diff --git a/SignalServiceKit/Concurrency/CooperativeTimeout.swift b/SignalServiceKit/Concurrency/CooperativeTimeout.swift index d4bf413e2c..7cdfd5aac3 100644 --- a/SignalServiceKit/Concurrency/CooperativeTimeout.swift +++ b/SignalServiceKit/Concurrency/CooperativeTimeout.swift @@ -19,11 +19,11 @@ public struct CooperativeTimeoutError: Error {} /// value from invoking `operation`, the error thrown when invoking /// `operation`, or the `CooperativeTimeoutError` thrown after `seconds`. public func withCooperativeTimeout(seconds: TimeInterval, operation: @escaping () async throws -> T) async throws -> T { - return try await withThrowingTaskGroup(of: T?.self) { taskGroup in - taskGroup.addTask { - return try await operation() - } - taskGroup.addTask { + let results = await _withCooperativeRace(operations: [ + { () async throws -> T? in + try await operation() + }, + { () async throws -> T? in do { try await Task.sleep(nanoseconds: seconds.clampedNanoseconds) } catch { @@ -34,21 +34,54 @@ public func withCooperativeTimeout(seconds: TimeInterval, operation: @escapin return nil } throw CooperativeTimeoutError() - } - if let result = try await taskGroup.next()! { - // If the first child Task to finish throws an error, that error will be - // rethrown on the prior line. When `withThrowingTaskGroup` throws an error - // from its body, it cancels all the other child Tasks. If, however, the - // first child Task to finish doesn't throw an error, and if it's the child - // Task that's invoking operation(), we must cancel the other one to avoid - // waiting for it to time out. - taskGroup.cancelAll() - return result - } else { - // If the first result is nil, it means the cooperative timeout child Task - // was canceled. We must to wait for the other task to finish -- the one - // invoking operation() -- and pass through its result to the caller. - return try await taskGroup.next()!! + }, + ]) + for result in results { + if let operationResult = try result.get() { + return operationResult } } + // There are always two results. If at least one of them throws an Error, + // that error will be re-thrown above, and we can't reach this code. If + // neither of them throws an Error, the result from invoking operation() + // will be nonnil, it will be returned above, and we can't reach this code. + owsFail("Can't reach this code.") +} + +/// Invokes `operation` & `operations`, passing through the earliest result. +/// +/// This method doesn't return until `operation` and every element of +/// `operations` returns. In other words, every operation must cooperate +/// with the cancellation request. +/// +/// The `operation` and `operations` parameters are separated in the +/// function signature to require callers to provide at least one operation. +/// This provides compile-time safety for the `.first!` forced unwrap. +public func withCooperativeRace( + _ operation: @escaping () async throws -> T, + _ operations: (() async throws -> T)..., +) async throws -> T { + return try await _withCooperativeRace(operations: [operation] + operations).first!.get() +} + +private func _withCooperativeRace(operations: [() async throws -> T]) async -> [Result] { + return await withThrowingTaskGroup { taskGroup in + for operation in operations { + taskGroup.addTask { + return try await operation() + } + } + var results = [Result]() + if let firstResult = await taskGroup.nextResult() { + results.append(firstResult) + // Cancel everything else as soon as anything wins the race. + taskGroup.cancelAll() + // This is cooperative, so even though we canceled all the other + // operations, they may still produce meaningful results. + while let otherResult = await taskGroup.nextResult() { + results.append(otherResult) + } + } + return results + } }