Structure concurrency for OWSUrlSession’s progress
This commit is contained in:
parent
0213818227
commit
0758c26868
@ -705,6 +705,7 @@
|
||||
504861AB2EEB4D1500B13C49 /* SignalAttachment.swift in Sources */ = {isa = PBXBuildFile; fileRef = 34D913491F62D4A500722898 /* SignalAttachment.swift */; };
|
||||
5049246F2F209DA8006469B3 /* OutgoingMessageRequestResponseSyncMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5049246E2F209DA8006469B3 /* OutgoingMessageRequestResponseSyncMessage.swift */; };
|
||||
5049FA2F28BEAABE00D6E099 /* ContactDiscoveryV2Operation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5049FA2D28BEAABE00D6E099 /* ContactDiscoveryV2Operation.swift */; };
|
||||
504B8BAC2F802FA2002B8AB9 /* DeferredContinuation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 504B8BAB2F802FA2002B8AB9 /* DeferredContinuation.swift */; };
|
||||
504F397C29D23B1700E849A6 /* ValidatedIncomingEnvelope.swift in Sources */ = {isa = PBXBuildFile; fileRef = 504F397B29D23B1700E849A6 /* ValidatedIncomingEnvelope.swift */; };
|
||||
504F98B12EAFFAC600DF465B /* KyberPreKeyUseRecord.swift in Sources */ = {isa = PBXBuildFile; fileRef = 504F98B02EAFFAC600DF465B /* KyberPreKeyUseRecord.swift */; };
|
||||
504F98B32EB0270A00DF465B /* SendMessageFailure.swift in Sources */ = {isa = PBXBuildFile; fileRef = 504F98B22EB0270A00DF465B /* SendMessageFailure.swift */; };
|
||||
@ -4974,6 +4975,7 @@
|
||||
50468F2A29EE19C300948E02 /* PhoneNumberChangedMessageInserterTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PhoneNumberChangedMessageInserterTest.swift; sourceTree = "<group>"; };
|
||||
5049246E2F209DA8006469B3 /* OutgoingMessageRequestResponseSyncMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OutgoingMessageRequestResponseSyncMessage.swift; sourceTree = "<group>"; };
|
||||
5049FA2D28BEAABE00D6E099 /* ContactDiscoveryV2Operation.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ContactDiscoveryV2Operation.swift; sourceTree = "<group>"; };
|
||||
504B8BAB2F802FA2002B8AB9 /* DeferredContinuation.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DeferredContinuation.swift; sourceTree = "<group>"; };
|
||||
504F397B29D23B1700E849A6 /* ValidatedIncomingEnvelope.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ValidatedIncomingEnvelope.swift; sourceTree = "<group>"; };
|
||||
504F98B02EAFFAC600DF465B /* KyberPreKeyUseRecord.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = KyberPreKeyUseRecord.swift; sourceTree = "<group>"; };
|
||||
504F98B22EB0270A00DF465B /* SendMessageFailure.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SendMessageFailure.swift; sourceTree = "<group>"; };
|
||||
@ -9634,6 +9636,7 @@
|
||||
50C0203B2CA4A61E00BDC4EF /* ConcurrentTaskQueue.swift */,
|
||||
500AF3AE2C58366700CB9F4F /* CooperativeTimeout.swift */,
|
||||
500AF3B02C58385600CB9F4F /* CooperativeTimeoutTest.swift */,
|
||||
504B8BAB2F802FA2002B8AB9 /* DeferredContinuation.swift */,
|
||||
505C2BC12E85D72B0009237F /* KeyedConcurrentTaskQueue.swift */,
|
||||
50B0BCA22DFA17BE0076B680 /* Monitor.swift */,
|
||||
50C0203D2CA4A7A500BDC4EF /* Retry.swift */,
|
||||
@ -19225,6 +19228,7 @@
|
||||
C14EC1AB2BAB57B900A4D064 /* DecryptingStreamTransform.swift in Sources */,
|
||||
F9C5CD27289453B300548EEE /* DeepCopy.swift in Sources */,
|
||||
F9C5CC11289453B300548EEE /* DefaultStickers.swift in Sources */,
|
||||
504B8BAC2F802FA2002B8AB9 /* DeferredContinuation.swift in Sources */,
|
||||
D91AC9322B61AD9A00814975 /* DeletedCallRecord.swift in Sources */,
|
||||
D9CA8AB02B698DFF00787167 /* DeletedCallRecordExpirationJob.swift in Sources */,
|
||||
D91AC9342B61C1F000814975 /* DeletedCallRecordStore.swift in Sources */,
|
||||
|
||||
@ -416,7 +416,6 @@ private enum DebugLogUploader {
|
||||
mimeType: mimeType,
|
||||
textParts: textParts,
|
||||
ignoreAppExpiry: true,
|
||||
progress: nil,
|
||||
)
|
||||
|
||||
let statusCode = response.responseStatusCode
|
||||
|
||||
@ -11,20 +11,13 @@ import Foundation
|
||||
/// example, when waiting for an event to occur, "cancellation" means "stop
|
||||
/// waiting for the event to occur" and not "stop the event from occurring".
|
||||
public struct CancellableContinuation<T>: Sendable {
|
||||
private enum State {
|
||||
case initial
|
||||
case waiting(CheckedContinuation<T, Error>)
|
||||
case completed(Result<T, Error>)
|
||||
case consumed
|
||||
}
|
||||
|
||||
private let state = AtomicValue<State>(State.initial, lock: .init())
|
||||
private let deferredContinuation = DeferredContinuation<T>()
|
||||
|
||||
public init() {
|
||||
}
|
||||
|
||||
func cancel() {
|
||||
self.resume(with: .failure(CancellationError()))
|
||||
self.deferredContinuation.resume(with: .failure(CancellationError()))
|
||||
}
|
||||
|
||||
/// Resumes the continuation with `result`.
|
||||
@ -32,49 +25,13 @@ public struct CancellableContinuation<T>: Sendable {
|
||||
/// It's safe (and harmless) to call `resume` multiple times; redundant
|
||||
/// invocations are ignored.
|
||||
public func resume(with result: Result<T, Error>) {
|
||||
let continuation = self.state.update { state -> CheckedContinuation<T, Error>? in
|
||||
switch state {
|
||||
case .initial:
|
||||
state = .completed(result)
|
||||
return nil
|
||||
case .waiting(let continuation):
|
||||
state = .consumed
|
||||
return continuation
|
||||
case .completed(_), .consumed:
|
||||
// Ignore the new result.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if let continuation {
|
||||
continuation.resume(with: result)
|
||||
}
|
||||
self.deferredContinuation.resume(with: result)
|
||||
}
|
||||
|
||||
/// Waits for the result. Should only be called once per instance!
|
||||
public func wait() async throws -> T {
|
||||
try await withTaskCancellationHandler(
|
||||
operation: {
|
||||
return try await withCheckedThrowingContinuation { continuation in
|
||||
let result = self.state.update { state -> Result<T, Error>? in
|
||||
switch state {
|
||||
case .initial:
|
||||
state = .waiting(continuation)
|
||||
return nil
|
||||
case .completed(let result):
|
||||
state = .consumed
|
||||
return result
|
||||
case .waiting(_), .consumed:
|
||||
continuation.resume(throwing: OWSAssertionError(
|
||||
"should not await a CancellableContinuation multiple times",
|
||||
))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if let result {
|
||||
continuation.resume(with: result)
|
||||
}
|
||||
}
|
||||
},
|
||||
operation: { try await self.deferredContinuation.wait() },
|
||||
onCancel: { self.cancel() },
|
||||
)
|
||||
}
|
||||
|
||||
68
SignalServiceKit/Concurrency/DeferredContinuation.swift
Normal file
68
SignalServiceKit/Concurrency/DeferredContinuation.swift
Normal file
@ -0,0 +1,68 @@
|
||||
//
|
||||
// Copyright 2026 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/// A container that stores a Result before creating a continuation.
|
||||
public struct DeferredContinuation<T>: Sendable {
|
||||
private enum State {
|
||||
case initial
|
||||
case waiting(CheckedContinuation<T, Error>)
|
||||
case completed(Result<T, Error>)
|
||||
case consumed
|
||||
}
|
||||
|
||||
private let state = AtomicValue<State>(State.initial, lock: .init())
|
||||
|
||||
public init() {
|
||||
}
|
||||
|
||||
/// Resumes the continuation with `result`.
|
||||
///
|
||||
/// It's safe (and harmless) to call `resume` multiple times; redundant
|
||||
/// invocations are ignored.
|
||||
public func resume(with result: Result<T, Error>) {
|
||||
let continuation = self.state.update { state -> CheckedContinuation<T, Error>? in
|
||||
switch state {
|
||||
case .initial:
|
||||
state = .completed(result)
|
||||
return nil
|
||||
case .waiting(let continuation):
|
||||
state = .consumed
|
||||
return continuation
|
||||
case .completed(_), .consumed:
|
||||
// Ignore the new result.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if let continuation {
|
||||
continuation.resume(with: result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Waits for the result. Should only be called once per instance!
|
||||
public func wait() async throws -> T {
|
||||
return try await withCheckedThrowingContinuation { continuation in
|
||||
let result = self.state.update { state -> Result<T, Error>? in
|
||||
switch state {
|
||||
case .initial:
|
||||
state = .waiting(continuation)
|
||||
return nil
|
||||
case .completed(let result):
|
||||
state = .consumed
|
||||
return result
|
||||
case .waiting(_), .consumed:
|
||||
continuation.resume(throwing: OWSAssertionError(
|
||||
"should not await a DeferredContinuation multiple times",
|
||||
))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if let result {
|
||||
continuation.resume(with: result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1731,7 +1731,7 @@ public class AttachmentDownloadManagerImpl: AttachmentDownloadManager {
|
||||
return try await urlSession.performDownload(
|
||||
requestUrl: requestUrl,
|
||||
resumeData: resumeData,
|
||||
progress: wrappedProgressSource,
|
||||
progressBlock: wrappedProgressSource.asProgressBlock(),
|
||||
)
|
||||
}
|
||||
downloadResponse = try await downloadTask!.value
|
||||
@ -1741,7 +1741,7 @@ public class AttachmentDownloadManagerImpl: AttachmentDownloadManager {
|
||||
urlPath,
|
||||
method: .get,
|
||||
headers: headers,
|
||||
progress: wrappedProgressSource,
|
||||
progressBlock: wrappedProgressSource.asProgressBlock(),
|
||||
)
|
||||
}
|
||||
downloadResponse = try await downloadTask!.value
|
||||
|
||||
@ -99,7 +99,7 @@ public class BaseOWSURLSessionMock: OWSURLSessionProtocol {
|
||||
public func performUpload(
|
||||
request: URLRequest,
|
||||
requestData: Data,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: OWSURLSession.ProgressBlock,
|
||||
) async throws -> HTTPResponse {
|
||||
// Want different behavior? Write a custom mock class
|
||||
return HTTPResponse(
|
||||
@ -114,7 +114,7 @@ public class BaseOWSURLSessionMock: OWSURLSessionProtocol {
|
||||
request: URLRequest,
|
||||
fileUrl: URL,
|
||||
ignoreAppExpiry: Bool,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: OWSURLSession.ProgressBlock,
|
||||
) async throws -> HTTPResponse {
|
||||
// Want different behavior? Write a custom mock class
|
||||
return HTTPResponse(
|
||||
@ -137,7 +137,7 @@ public class BaseOWSURLSessionMock: OWSURLSessionProtocol {
|
||||
|
||||
public func performDownload(
|
||||
request: URLRequest,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: OWSURLSession.ProgressBlock,
|
||||
) async throws -> OWSUrlDownloadResponse {
|
||||
// Want different behavior? Write a custom mock class
|
||||
return OWSUrlDownloadResponse(
|
||||
@ -149,7 +149,7 @@ public class BaseOWSURLSessionMock: OWSURLSessionProtocol {
|
||||
public func performDownload(
|
||||
requestUrl: URL,
|
||||
resumeData: Data,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: OWSURLSession.ProgressBlock,
|
||||
) async throws -> OWSUrlDownloadResponse {
|
||||
// Want different behavior? Write a custom mock class
|
||||
return OWSUrlDownloadResponse(
|
||||
|
||||
@ -127,14 +127,14 @@ public protocol OWSURLSessionProtocol: AnyObject {
|
||||
func performUpload(
|
||||
request: URLRequest,
|
||||
requestData: Data,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: OWSURLSession.ProgressBlock,
|
||||
) async throws -> HTTPResponse
|
||||
|
||||
func performUpload(
|
||||
request: URLRequest,
|
||||
fileUrl: URL,
|
||||
ignoreAppExpiry: Bool,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: OWSURLSession.ProgressBlock,
|
||||
) async throws -> HTTPResponse
|
||||
|
||||
func performRequest(
|
||||
@ -145,12 +145,12 @@ public protocol OWSURLSessionProtocol: AnyObject {
|
||||
func performDownload(
|
||||
requestUrl: URL,
|
||||
resumeData: Data,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: OWSURLSession.ProgressBlock,
|
||||
) async throws -> OWSUrlDownloadResponse
|
||||
|
||||
func performDownload(
|
||||
request: URLRequest,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: OWSURLSession.ProgressBlock,
|
||||
) async throws -> OWSUrlDownloadResponse
|
||||
|
||||
func webSocketTask(
|
||||
@ -188,10 +188,10 @@ public extension OWSURLSessionProtocol {
|
||||
method: HTTPMethod,
|
||||
headers: HttpHeaders = HttpHeaders(),
|
||||
requestData: Data,
|
||||
progress: OWSProgressSource? = nil,
|
||||
progressBlock: OWSURLSession.ProgressBlock = { _, _ in },
|
||||
) async throws -> HTTPResponse {
|
||||
let request = try self.endpoint.buildRequest(urlString, method: method, headers: headers, body: requestData)
|
||||
return try await self.performUpload(request: request, requestData: requestData, progress: progress)
|
||||
return try await self.performUpload(request: request, requestData: requestData, progressBlock: progressBlock)
|
||||
}
|
||||
|
||||
func performUpload(
|
||||
@ -199,14 +199,14 @@ public extension OWSURLSessionProtocol {
|
||||
method: HTTPMethod,
|
||||
headers: HttpHeaders = HttpHeaders(),
|
||||
fileUrl: URL,
|
||||
progress: OWSProgressSource? = nil,
|
||||
progressBlock: OWSURLSession.ProgressBlock = { _, _ in },
|
||||
) async throws -> HTTPResponse {
|
||||
let request = try self.endpoint.buildRequest(urlString, method: method, headers: headers)
|
||||
return try await self.performUpload(
|
||||
request: request,
|
||||
fileUrl: fileUrl,
|
||||
ignoreAppExpiry: false,
|
||||
progress: progress,
|
||||
progressBlock: progressBlock,
|
||||
)
|
||||
}
|
||||
|
||||
@ -230,10 +230,10 @@ public extension OWSURLSessionProtocol {
|
||||
method: HTTPMethod,
|
||||
headers: HttpHeaders = HttpHeaders(),
|
||||
body: Data? = nil,
|
||||
progress: OWSProgressSource? = nil,
|
||||
progressBlock: OWSURLSession.ProgressBlock = { _, _ in },
|
||||
) async throws -> OWSUrlDownloadResponse {
|
||||
let request = try self.endpoint.buildRequest(urlString, method: method, headers: headers, body: body)
|
||||
return try await self.performDownload(request: request, progress: progress)
|
||||
return try await self.performDownload(request: request, progressBlock: progressBlock)
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,7 +249,7 @@ extension OWSURLSessionProtocol {
|
||||
mimeType: String,
|
||||
textParts textPartsDictionary: OrderedDictionary<String, String>,
|
||||
ignoreAppExpiry: Bool = false,
|
||||
progress: OWSProgressSource? = nil,
|
||||
progressBlock: OWSURLSession.ProgressBlock = { _, _ in },
|
||||
) async throws -> HTTPResponse {
|
||||
let multipartBodyFileURL = OWSFileSystem.temporaryFileUrl(
|
||||
fileExtension: nil,
|
||||
@ -289,7 +289,7 @@ extension OWSURLSessionProtocol {
|
||||
request: request,
|
||||
fileUrl: multipartBodyFileURL,
|
||||
ignoreAppExpiry: ignoreAppExpiry,
|
||||
progress: progress,
|
||||
progressBlock: progressBlock,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,6 +11,8 @@ public enum OWSURLSessionError: Error {
|
||||
|
||||
public class OWSURLSession: OWSURLSessionProtocol {
|
||||
|
||||
public typealias ProgressBlock = (_ completedByteCount: Int64, _ totalByteCount: Int64) async -> Void
|
||||
|
||||
// MARK: - OWSURLSessionProtocol conformance
|
||||
|
||||
public let endpoint: OWSURLSessionEndpoint
|
||||
@ -149,12 +151,12 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
public func performUpload(
|
||||
request: URLRequest,
|
||||
requestData: Data,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: ProgressBlock,
|
||||
) async throws -> HTTPResponse {
|
||||
return try await performUpload(
|
||||
request: request,
|
||||
ignoreAppExpiry: false,
|
||||
progress: progress,
|
||||
progressBlock: progressBlock,
|
||||
taskBlock: { self.session.uploadTask(with: request, from: requestData) },
|
||||
)
|
||||
}
|
||||
@ -163,12 +165,12 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
request: URLRequest,
|
||||
fileUrl: URL,
|
||||
ignoreAppExpiry: Bool,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: ProgressBlock,
|
||||
) async throws -> HTTPResponse {
|
||||
return try await performUpload(
|
||||
request: request,
|
||||
ignoreAppExpiry: ignoreAppExpiry,
|
||||
progress: progress,
|
||||
progressBlock: progressBlock,
|
||||
taskBlock: { self.session.uploadTask(with: request, fromFile: fileUrl) },
|
||||
)
|
||||
}
|
||||
@ -182,9 +184,11 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
let requestConfig = self.requestConfig(requestUrl: request.url!)
|
||||
let task = session.dataTask(with: request)
|
||||
|
||||
let (urlResponse, responseData) = try await runTask(task, taskState: {
|
||||
return DataTaskState(progressSource: nil, completion: $0)
|
||||
})
|
||||
let (urlResponse, responseData) = try await runTask(
|
||||
task,
|
||||
taskState: { DataTaskState(progress: $0, completion: $1) },
|
||||
progressBlock: { _, _ in },
|
||||
)
|
||||
|
||||
return try await handleDataResult(
|
||||
urlResponse: urlResponse,
|
||||
@ -196,13 +200,13 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
|
||||
public func performDownload(
|
||||
request: URLRequest,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: ProgressBlock,
|
||||
) async throws -> OWSUrlDownloadResponse {
|
||||
let request = prepareRequest(request: request)
|
||||
guard let requestUrl = request.url else {
|
||||
throw OWSAssertionError("Request missing url.")
|
||||
}
|
||||
return try await performDownload(requestUrl: requestUrl, progress: progress) {
|
||||
return try await performDownload(requestUrl: requestUrl, progressBlock: progressBlock) {
|
||||
// Don't use a completion block or the delegate will be ignored for download tasks.
|
||||
return self.session.downloadTask(with: request)
|
||||
}
|
||||
@ -211,9 +215,9 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
public func performDownload(
|
||||
requestUrl: URL,
|
||||
resumeData: Data,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: ProgressBlock,
|
||||
) async throws -> OWSUrlDownloadResponse {
|
||||
return try await performDownload(requestUrl: requestUrl, progress: progress) {
|
||||
return try await performDownload(requestUrl: requestUrl, progressBlock: progressBlock) {
|
||||
// Don't use a completion block or the delegate will be ignored for download tasks.
|
||||
return self.session.downloadTask(withResumeData: resumeData)
|
||||
}
|
||||
@ -452,7 +456,7 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
|
||||
do {
|
||||
rawRequest.logger.info("Sending… -> \(rawRequest)")
|
||||
let response = try await performUpload(request: request, requestData: requestBody, progress: nil)
|
||||
let response = try await performUpload(request: request, requestData: requestBody, progressBlock: { _, _ in })
|
||||
rawRequest.logger.info("HTTP \(response.responseStatusCode) <- \(rawRequest)")
|
||||
return response
|
||||
} catch where error.httpStatusCode != nil {
|
||||
@ -467,7 +471,7 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
private func performUpload(
|
||||
request: URLRequest,
|
||||
ignoreAppExpiry: Bool,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: ProgressBlock,
|
||||
taskBlock: () -> URLSessionUploadTask,
|
||||
) async throws -> HTTPResponse {
|
||||
if !ignoreAppExpiry, DependenciesBridge.shared.appExpiry.isExpired(now: Date()) {
|
||||
@ -481,9 +485,11 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
let urlResponse: URLResponse?
|
||||
let responseData: Data
|
||||
do {
|
||||
(urlResponse, responseData) = try await runTask(task, taskState: {
|
||||
return DataTaskState(progressSource: progress, completion: $0)
|
||||
})
|
||||
(urlResponse, responseData) = try await runTask(
|
||||
task,
|
||||
taskState: { DataTaskState(progress: $0, completion: $1) },
|
||||
progressBlock: progressBlock,
|
||||
)
|
||||
} catch {
|
||||
throw handleError(error, originalRequest: task.originalRequest, requestConfig: requestConfig)
|
||||
}
|
||||
@ -497,7 +503,7 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
|
||||
private func performDownload(
|
||||
requestUrl: URL,
|
||||
progress: OWSProgressSource?,
|
||||
progressBlock: ProgressBlock,
|
||||
taskBlock: () -> URLSessionDownloadTask,
|
||||
) async throws -> OWSUrlDownloadResponse {
|
||||
let appExpiry = DependenciesBridge.shared.appExpiry
|
||||
@ -508,9 +514,11 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
let requestConfig = self.requestConfig(requestUrl: requestUrl)
|
||||
let task = taskBlock()
|
||||
|
||||
let (urlResponse, downloadUrl) = try await runTask(task, taskState: {
|
||||
return DownloadTaskState(progressSource: progress, completion: $0)
|
||||
})
|
||||
let (urlResponse, downloadUrl) = try await runTask(
|
||||
task,
|
||||
taskState: { DownloadTaskState(progress: $0, completion: $1) },
|
||||
progressBlock: progressBlock,
|
||||
)
|
||||
|
||||
return try await handleDownloadResult(
|
||||
urlResponse: urlResponse,
|
||||
@ -520,7 +528,11 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
)
|
||||
}
|
||||
|
||||
private func runTask<T>(_ task: URLSessionTask, taskState: (CheckedContinuation<T, any Error>) -> some TaskState) async throws -> T {
|
||||
private func runTask<T>(
|
||||
_ task: URLSessionTask,
|
||||
taskState: (TaskState.ProgressContinuation, DeferredContinuation<T>) -> some TaskState,
|
||||
progressBlock: ProgressBlock,
|
||||
) async throws -> T {
|
||||
// It's possible for operation and onCancel to race one another, so we use
|
||||
// a counter to ensure that cancellation happens after addTask is invoked.
|
||||
// (You can trigger this by sending a request from a canceled Task.)
|
||||
@ -528,8 +540,9 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
|
||||
return try await withTaskCancellationHandler(
|
||||
operation: {
|
||||
return try await withCheckedThrowingContinuation { continuation in
|
||||
self.addTask(task, taskState: taskState(continuation))
|
||||
let completion = DeferredContinuation<T>()
|
||||
let progressStream = AsyncStream(bufferingPolicy: .bufferingNewest(1)) { continuation in
|
||||
self.addTask(task, taskState: taskState(continuation, completion))
|
||||
// If cancel was already called, cancel it now.
|
||||
if cancelState.increment() == 2 {
|
||||
task.cancel()
|
||||
@ -537,6 +550,10 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
task.resume()
|
||||
}
|
||||
}
|
||||
for await progressUpdate in progressStream {
|
||||
await progressBlock(progressUpdate.completedByteCount, progressUpdate.totalByteCount)
|
||||
}
|
||||
return try await completion.wait()
|
||||
},
|
||||
onCancel: {
|
||||
// If the task was already added, cancel it now.
|
||||
@ -569,9 +586,9 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
}
|
||||
}
|
||||
|
||||
private func progressSource(forTask task: URLSessionTask) -> OWSProgressSource? {
|
||||
private func progress(forTask task: URLSessionTask) -> TaskState.ProgressContinuation? {
|
||||
return updateTaskStates {
|
||||
return $0[task.taskIdentifier]?.progressSource
|
||||
return $0[task.taskIdentifier]?.progress
|
||||
}
|
||||
}
|
||||
|
||||
@ -606,7 +623,8 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
owsFailDebug("Missing TaskState.")
|
||||
return
|
||||
}
|
||||
taskState.completion.resume(returning: (task.response, downloadUrl))
|
||||
taskState.progress?.finish()
|
||||
taskState.completion.resume(with: .success((task.response, downloadUrl)))
|
||||
}
|
||||
|
||||
private func dataTaskDidSucceed(_ task: URLSessionTask) {
|
||||
@ -615,7 +633,8 @@ public class OWSURLSession: OWSURLSessionProtocol {
|
||||
return
|
||||
}
|
||||
let responseData = taskState.pendingData.get()
|
||||
taskState.completion.resume(returning: (task.response, responseData))
|
||||
taskState.progress?.finish()
|
||||
taskState.completion.resume(with: .success((task.response, responseData)))
|
||||
}
|
||||
|
||||
private func taskDidFail(_ task: URLSessionTask, error: Error) {
|
||||
@ -703,13 +722,8 @@ extension OWSURLSession {
|
||||
}
|
||||
|
||||
func urlSession(_ session: URLSession, task: URLSessionTask, didSendBodyData bytesSent: Int64, totalBytesSent: Int64, totalBytesExpectedToSend: Int64) {
|
||||
guard let progressSource = self.progressSource(forTask: task) else {
|
||||
return
|
||||
}
|
||||
// TODO: We could check for NSURLSessionTransferSizeUnknown here.
|
||||
if progressSource.completedUnitCount < totalBytesSent {
|
||||
progressSource.incrementCompletedUnitCount(by: UInt64(totalBytesSent) - progressSource.completedUnitCount)
|
||||
}
|
||||
self.progress(forTask: task)?.yield((totalBytesSent, totalBytesExpectedToSend))
|
||||
}
|
||||
|
||||
func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
|
||||
@ -750,12 +764,7 @@ extension OWSURLSession {
|
||||
taskDidFail(downloadTask, error: OWSURLSessionError.responseTooLarge)
|
||||
return
|
||||
}
|
||||
guard let progressSource = self.progressSource(forTask: downloadTask) else {
|
||||
return
|
||||
}
|
||||
if progressSource.completedUnitCount < totalBytesWritten {
|
||||
progressSource.incrementCompletedUnitCount(by: UInt64(totalBytesWritten) - progressSource.completedUnitCount)
|
||||
}
|
||||
self.progress(forTask: downloadTask)?.yield((totalBytesWritten, totalBytesExpectedToWrite))
|
||||
}
|
||||
|
||||
func urlSession(
|
||||
@ -768,12 +777,7 @@ extension OWSURLSession {
|
||||
taskDidFail(downloadTask, error: OWSURLSessionError.responseTooLarge)
|
||||
return
|
||||
}
|
||||
guard let progressSource = self.progressSource(forTask: downloadTask) else {
|
||||
return
|
||||
}
|
||||
if progressSource.completedUnitCount < fileOffset {
|
||||
progressSource.incrementCompletedUnitCount(by: UInt64(fileOffset) - progressSource.completedUnitCount)
|
||||
}
|
||||
self.progress(forTask: downloadTask)?.yield((fileOffset, expectedTotalBytes))
|
||||
}
|
||||
|
||||
func urlSession(
|
||||
@ -811,25 +815,25 @@ extension OWSURLSession {
|
||||
// MARK: - TaskState
|
||||
|
||||
private protocol TaskState {
|
||||
typealias ProgressBlock = (URLSessionTask, Progress) -> Void
|
||||
var progressSource: OWSProgressSource? { get }
|
||||
typealias ProgressContinuation = AsyncStream<(completedByteCount: Int64, totalByteCount: Int64)>.Continuation
|
||||
var progress: ProgressContinuation? { get }
|
||||
func reject(error: any Error, task: URLSessionTask)
|
||||
}
|
||||
|
||||
// MARK: - DownloadTaskState
|
||||
|
||||
private class DownloadTaskState: TaskState {
|
||||
typealias CompletionContinuation = CheckedContinuation<(URLResponse?, URL), any Error>
|
||||
let progressSource: OWSProgressSource?
|
||||
let completion: CompletionContinuation
|
||||
let progress: ProgressContinuation?
|
||||
let completion: DeferredContinuation<(URLResponse?, URL)>
|
||||
|
||||
init(progressSource: OWSProgressSource?, completion: CompletionContinuation) {
|
||||
self.progressSource = progressSource
|
||||
init(progress: ProgressContinuation, completion: DeferredContinuation<(URLResponse?, URL)>) {
|
||||
self.progress = progress
|
||||
self.completion = completion
|
||||
}
|
||||
|
||||
func reject(error: any Error, task: URLSessionTask) {
|
||||
completion.resume(throwing: error)
|
||||
self.progress?.finish()
|
||||
self.completion.resume(with: .failure(error))
|
||||
}
|
||||
}
|
||||
|
||||
@ -837,19 +841,18 @@ private class DownloadTaskState: TaskState {
|
||||
|
||||
/// Also used for upload tasks, which are a subclass data tasks.
|
||||
private class DataTaskState: TaskState {
|
||||
typealias CompletionContinuation = CheckedContinuation<(URLResponse?, Data), any Error>
|
||||
|
||||
let pendingData = AtomicValue<Data>(Data(), lock: .init())
|
||||
let progressSource: OWSProgressSource?
|
||||
let completion: CompletionContinuation
|
||||
let progress: ProgressContinuation?
|
||||
let completion: DeferredContinuation<(URLResponse?, Data)>
|
||||
|
||||
init(progressSource: OWSProgressSource?, completion: CompletionContinuation) {
|
||||
self.progressSource = progressSource
|
||||
init(progress: ProgressContinuation?, completion: DeferredContinuation<(URLResponse?, Data)>) {
|
||||
self.progress = progress
|
||||
self.completion = completion
|
||||
}
|
||||
|
||||
func reject(error: any Error, task: URLSessionTask) {
|
||||
self.completion.resume(throwing: error)
|
||||
self.progress?.finish()
|
||||
self.completion.resume(with: .failure(error))
|
||||
}
|
||||
}
|
||||
|
||||
@ -859,7 +862,7 @@ private class WebSocketTaskState: TaskState {
|
||||
typealias OpenBlock = (String?) -> Void
|
||||
typealias CloseBlock = (Error) -> Void
|
||||
|
||||
var progressSource: OWSProgressSource? { nil }
|
||||
var progress: ProgressContinuation? { nil }
|
||||
let openBlock: OpenBlock
|
||||
let closeBlock: CloseBlock
|
||||
|
||||
|
||||
@ -181,7 +181,7 @@ struct UploadEndpointCDN2: UploadEndpoint {
|
||||
method: .put,
|
||||
headers: headers,
|
||||
requestData: uploadData,
|
||||
progress: progress,
|
||||
progressBlock: progress?.asProgressBlock() ?? { _, _ in },
|
||||
)
|
||||
switch response.responseStatusCode {
|
||||
case 200, 201:
|
||||
|
||||
@ -130,7 +130,7 @@ struct UploadEndpointCDN3: UploadEndpoint {
|
||||
method: method,
|
||||
headers: headers,
|
||||
requestData: uploadData,
|
||||
progress: progress,
|
||||
progressBlock: progress?.asProgressBlock() ?? { _, _ in },
|
||||
)
|
||||
|
||||
switch response.responseStatusCode {
|
||||
|
||||
@ -339,6 +339,14 @@ extension OWSProgressSource {
|
||||
|
||||
extension OWSProgressSource where Self: Sendable {
|
||||
|
||||
func asProgressBlock() -> OWSURLSession.ProgressBlock {
|
||||
return { completedByteCount, totalByteCount in
|
||||
if self.completedUnitCount < completedByteCount {
|
||||
self.incrementCompletedUnitCount(by: UInt64(completedByteCount) - self.completedUnitCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Given some block of asynchronous work, update progress
|
||||
/// on the current source periodically (every ``timeInterval`` seconds)
|
||||
/// until the work block completes.
|
||||
|
||||
@ -8,7 +8,7 @@ import Foundation
|
||||
|
||||
typealias PerformTSRequestBlock = (TSRequest) async throws -> HTTPResponse
|
||||
typealias PerformRequestBlock = (URLRequest) async throws -> HTTPResponse
|
||||
typealias PerformUploadBlock = (URLRequest, Data, OWSProgressSource?) async throws -> HTTPResponse
|
||||
typealias PerformUploadBlock = (URLRequest, Data, OWSURLSession.ProgressBlock) async throws -> HTTPResponse
|
||||
|
||||
enum MockRequestType {
|
||||
case uploadForm(PerformTSRequestBlock)
|
||||
@ -162,12 +162,12 @@ class AttachmentUploadManagerMockHelper {
|
||||
}
|
||||
}
|
||||
|
||||
mockURLSession.performUploadDataBlock = { request, data, progress in
|
||||
mockURLSession.performUploadDataBlock = { request, data, progressBlock in
|
||||
guard case let .uploadTask(requestBlock) = self.activeUploadRequestMocks.removeFirst() else {
|
||||
throw OWSAssertionError("Mock request missing")
|
||||
}
|
||||
self.capturedRequests.append(.uploadTask(request))
|
||||
return try await requestBlock(request, data, progress)
|
||||
return try await requestBlock(request, data, progressBlock)
|
||||
}
|
||||
|
||||
return insertMockAttachment(mockAttachment)
|
||||
@ -342,10 +342,15 @@ class AttachmentUploadManagerMockHelper {
|
||||
case networkTimeout
|
||||
}
|
||||
|
||||
func addUploadRequestMock(auth: String, location: String, type: UploadResultType, completedCount: UInt64? = nil) {
|
||||
enqueue(auth: auth, request: .uploadTask({ request, url, progress in
|
||||
func addUploadRequestMock(
|
||||
auth: String,
|
||||
location: String,
|
||||
type: UploadResultType,
|
||||
completedCount: (completedByteCount: Int64, totalByteCount: Int64)? = nil,
|
||||
) {
|
||||
enqueue(auth: auth, request: .uploadTask({ request, url, progressBlock in
|
||||
if let completedCount {
|
||||
progress?.incrementCompletedUnitCount(by: completedCount)
|
||||
await progressBlock(completedCount.completedByteCount, completedCount.totalByteCount)
|
||||
}
|
||||
switch type {
|
||||
case .networkTimeout:
|
||||
|
||||
@ -71,14 +71,14 @@ class _AttachmentUploadManager_NetworkManagerMock: NetworkManager {
|
||||
|
||||
public class _AttachmentUploadManager_OWSURLSessionMock: BaseOWSURLSessionMock {
|
||||
|
||||
public var performUploadDataBlock: ((URLRequest, Data, OWSProgressSource?) async throws -> HTTPResponse)?
|
||||
override public func performUpload(request: URLRequest, requestData: Data, progress: OWSProgressSource?) async throws -> HTTPResponse {
|
||||
return try await performUploadDataBlock!(request, requestData, progress)
|
||||
public var performUploadDataBlock: ((URLRequest, Data, OWSURLSession.ProgressBlock) async throws -> HTTPResponse)?
|
||||
override public func performUpload(request: URLRequest, requestData: Data, progressBlock: OWSURLSession.ProgressBlock) async throws -> HTTPResponse {
|
||||
return try await performUploadDataBlock!(request, requestData, progressBlock)
|
||||
}
|
||||
|
||||
public var performUploadFileBlock: ((URLRequest, URL, Bool, OWSProgressSource?) async throws -> HTTPResponse)?
|
||||
override public func performUpload(request: URLRequest, fileUrl: URL, ignoreAppExpiry: Bool, progress: OWSProgressSource?) async throws -> HTTPResponse {
|
||||
return try await performUploadFileBlock!(request, fileUrl, ignoreAppExpiry, progress)
|
||||
public var performUploadFileBlock: ((URLRequest, URL, Bool, OWSURLSession.ProgressBlock) async throws -> HTTPResponse)?
|
||||
override public func performUpload(request: URLRequest, fileUrl: URL, ignoreAppExpiry: Bool, progressBlock: OWSURLSession.ProgressBlock) async throws -> HTTPResponse {
|
||||
return try await performUploadFileBlock!(request, fileUrl, ignoreAppExpiry, progressBlock)
|
||||
}
|
||||
|
||||
public var performRequestBlock: ((URLRequest) async throws -> HTTPResponse)?
|
||||
|
||||
@ -372,25 +372,25 @@ struct AttachmentUploadManagerTests {
|
||||
// 1. Upload location request
|
||||
let attempt = helper.addUploadFormAndLocationRequestMock(cdn: cdn) { auth, _, location in
|
||||
// 2. Fail the upload with a server error
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: 20)
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: (20, 20))
|
||||
|
||||
// 3. Fetch the remote progress, but find none
|
||||
helper.addResumeProgressMock(cdn: cdn, auth: auth, location: location, type: .missingRange)
|
||||
|
||||
// 4. Fail the upload with a server error
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: 20)
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: (20, 20))
|
||||
|
||||
// 5. Fetch the remote progress, but find none
|
||||
helper.addResumeProgressMock(cdn: cdn, auth: auth, location: location, type: .missingRange)
|
||||
|
||||
// 6. Fail the upload with a server error
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: 20)
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: (20, 20))
|
||||
|
||||
// 7. Fetch the remote progress, but find none
|
||||
helper.addResumeProgressMock(cdn: cdn, auth: auth, location: location, type: .missingRange)
|
||||
|
||||
// 8. Succeed the upload
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .success(200), completedCount: 20)
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .success(200), completedCount: (20, 20))
|
||||
}
|
||||
|
||||
try await uploadManager.uploadTransitTierAttachment(attachmentId: attachmentID, progress: nil)
|
||||
@ -435,7 +435,7 @@ struct AttachmentUploadManagerTests {
|
||||
helper.addResumeProgressMock(cdn: cdn, auth: auth, location: location, type: .progress(count: 15))
|
||||
|
||||
// 8. Succeed the upload
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .success(200), completedCount: 20)
|
||||
helper.addUploadRequestMock(auth: auth, location: location, type: .success(200), completedCount: (20, 20))
|
||||
}
|
||||
|
||||
try await uploadManager.uploadTransitTierAttachment(attachmentId: attachmentID, progress: nil)
|
||||
|
||||
@ -439,7 +439,7 @@ private class MockDownloadSession: BaseOWSURLSessionMock {
|
||||
|
||||
var performDownloadSource: ((URL) async throws -> OWSUrlDownloadResponse)?
|
||||
|
||||
override func performDownload(request: URLRequest, progress: OWSProgressSource?) async throws -> OWSUrlDownloadResponse {
|
||||
override func performDownload(request: URLRequest, progressBlock: OWSURLSession.ProgressBlock) async throws -> OWSUrlDownloadResponse {
|
||||
return try await performDownloadSource!(request.url!)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user