Structure concurrency for OWSUrlSession’s progress

This commit is contained in:
Max Radermacher 2026-04-03 16:19:24 -05:00 committed by GitHub
parent 0213818227
commit 0758c26868
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 191 additions and 147 deletions

View File

@ -705,6 +705,7 @@
504861AB2EEB4D1500B13C49 /* SignalAttachment.swift in Sources */ = {isa = PBXBuildFile; fileRef = 34D913491F62D4A500722898 /* SignalAttachment.swift */; };
5049246F2F209DA8006469B3 /* OutgoingMessageRequestResponseSyncMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5049246E2F209DA8006469B3 /* OutgoingMessageRequestResponseSyncMessage.swift */; };
5049FA2F28BEAABE00D6E099 /* ContactDiscoveryV2Operation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5049FA2D28BEAABE00D6E099 /* ContactDiscoveryV2Operation.swift */; };
504B8BAC2F802FA2002B8AB9 /* DeferredContinuation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 504B8BAB2F802FA2002B8AB9 /* DeferredContinuation.swift */; };
504F397C29D23B1700E849A6 /* ValidatedIncomingEnvelope.swift in Sources */ = {isa = PBXBuildFile; fileRef = 504F397B29D23B1700E849A6 /* ValidatedIncomingEnvelope.swift */; };
504F98B12EAFFAC600DF465B /* KyberPreKeyUseRecord.swift in Sources */ = {isa = PBXBuildFile; fileRef = 504F98B02EAFFAC600DF465B /* KyberPreKeyUseRecord.swift */; };
504F98B32EB0270A00DF465B /* SendMessageFailure.swift in Sources */ = {isa = PBXBuildFile; fileRef = 504F98B22EB0270A00DF465B /* SendMessageFailure.swift */; };
@ -4974,6 +4975,7 @@
50468F2A29EE19C300948E02 /* PhoneNumberChangedMessageInserterTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PhoneNumberChangedMessageInserterTest.swift; sourceTree = "<group>"; };
5049246E2F209DA8006469B3 /* OutgoingMessageRequestResponseSyncMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OutgoingMessageRequestResponseSyncMessage.swift; sourceTree = "<group>"; };
5049FA2D28BEAABE00D6E099 /* ContactDiscoveryV2Operation.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ContactDiscoveryV2Operation.swift; sourceTree = "<group>"; };
504B8BAB2F802FA2002B8AB9 /* DeferredContinuation.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DeferredContinuation.swift; sourceTree = "<group>"; };
504F397B29D23B1700E849A6 /* ValidatedIncomingEnvelope.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ValidatedIncomingEnvelope.swift; sourceTree = "<group>"; };
504F98B02EAFFAC600DF465B /* KyberPreKeyUseRecord.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = KyberPreKeyUseRecord.swift; sourceTree = "<group>"; };
504F98B22EB0270A00DF465B /* SendMessageFailure.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SendMessageFailure.swift; sourceTree = "<group>"; };
@ -9634,6 +9636,7 @@
50C0203B2CA4A61E00BDC4EF /* ConcurrentTaskQueue.swift */,
500AF3AE2C58366700CB9F4F /* CooperativeTimeout.swift */,
500AF3B02C58385600CB9F4F /* CooperativeTimeoutTest.swift */,
504B8BAB2F802FA2002B8AB9 /* DeferredContinuation.swift */,
505C2BC12E85D72B0009237F /* KeyedConcurrentTaskQueue.swift */,
50B0BCA22DFA17BE0076B680 /* Monitor.swift */,
50C0203D2CA4A7A500BDC4EF /* Retry.swift */,
@ -19225,6 +19228,7 @@
C14EC1AB2BAB57B900A4D064 /* DecryptingStreamTransform.swift in Sources */,
F9C5CD27289453B300548EEE /* DeepCopy.swift in Sources */,
F9C5CC11289453B300548EEE /* DefaultStickers.swift in Sources */,
504B8BAC2F802FA2002B8AB9 /* DeferredContinuation.swift in Sources */,
D91AC9322B61AD9A00814975 /* DeletedCallRecord.swift in Sources */,
D9CA8AB02B698DFF00787167 /* DeletedCallRecordExpirationJob.swift in Sources */,
D91AC9342B61C1F000814975 /* DeletedCallRecordStore.swift in Sources */,

View File

@ -416,7 +416,6 @@ private enum DebugLogUploader {
mimeType: mimeType,
textParts: textParts,
ignoreAppExpiry: true,
progress: nil,
)
let statusCode = response.responseStatusCode

View File

@ -11,20 +11,13 @@ import Foundation
/// example, when waiting for an event to occur, "cancellation" means "stop
/// waiting for the event to occur" and not "stop the event from occurring".
public struct CancellableContinuation<T>: Sendable {
private enum State {
case initial
case waiting(CheckedContinuation<T, Error>)
case completed(Result<T, Error>)
case consumed
}
private let state = AtomicValue<State>(State.initial, lock: .init())
private let deferredContinuation = DeferredContinuation<T>()
public init() {
}
func cancel() {
self.resume(with: .failure(CancellationError()))
self.deferredContinuation.resume(with: .failure(CancellationError()))
}
/// Resumes the continuation with `result`.
@ -32,49 +25,13 @@ public struct CancellableContinuation<T>: Sendable {
/// It's safe (and harmless) to call `resume` multiple times; redundant
/// invocations are ignored.
public func resume(with result: Result<T, Error>) {
let continuation = self.state.update { state -> CheckedContinuation<T, Error>? in
switch state {
case .initial:
state = .completed(result)
return nil
case .waiting(let continuation):
state = .consumed
return continuation
case .completed(_), .consumed:
// Ignore the new result.
return nil
}
}
if let continuation {
continuation.resume(with: result)
}
self.deferredContinuation.resume(with: result)
}
/// Waits for the result. Should only be called once per instance!
public func wait() async throws -> T {
try await withTaskCancellationHandler(
operation: {
return try await withCheckedThrowingContinuation { continuation in
let result = self.state.update { state -> Result<T, Error>? in
switch state {
case .initial:
state = .waiting(continuation)
return nil
case .completed(let result):
state = .consumed
return result
case .waiting(_), .consumed:
continuation.resume(throwing: OWSAssertionError(
"should not await a CancellableContinuation multiple times",
))
return nil
}
}
if let result {
continuation.resume(with: result)
}
}
},
operation: { try await self.deferredContinuation.wait() },
onCancel: { self.cancel() },
)
}

View File

@ -0,0 +1,68 @@
//
// Copyright 2026 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
/// A container that stores a Result before creating a continuation.
public struct DeferredContinuation<T>: Sendable {
private enum State {
case initial
case waiting(CheckedContinuation<T, Error>)
case completed(Result<T, Error>)
case consumed
}
private let state = AtomicValue<State>(State.initial, lock: .init())
public init() {
}
/// Resumes the continuation with `result`.
///
/// It's safe (and harmless) to call `resume` multiple times; redundant
/// invocations are ignored.
public func resume(with result: Result<T, Error>) {
let continuation = self.state.update { state -> CheckedContinuation<T, Error>? in
switch state {
case .initial:
state = .completed(result)
return nil
case .waiting(let continuation):
state = .consumed
return continuation
case .completed(_), .consumed:
// Ignore the new result.
return nil
}
}
if let continuation {
continuation.resume(with: result)
}
}
/// Waits for the result. Should only be called once per instance!
public func wait() async throws -> T {
return try await withCheckedThrowingContinuation { continuation in
let result = self.state.update { state -> Result<T, Error>? in
switch state {
case .initial:
state = .waiting(continuation)
return nil
case .completed(let result):
state = .consumed
return result
case .waiting(_), .consumed:
continuation.resume(throwing: OWSAssertionError(
"should not await a DeferredContinuation multiple times",
))
return nil
}
}
if let result {
continuation.resume(with: result)
}
}
}
}

View File

@ -1731,7 +1731,7 @@ public class AttachmentDownloadManagerImpl: AttachmentDownloadManager {
return try await urlSession.performDownload(
requestUrl: requestUrl,
resumeData: resumeData,
progress: wrappedProgressSource,
progressBlock: wrappedProgressSource.asProgressBlock(),
)
}
downloadResponse = try await downloadTask!.value
@ -1741,7 +1741,7 @@ public class AttachmentDownloadManagerImpl: AttachmentDownloadManager {
urlPath,
method: .get,
headers: headers,
progress: wrappedProgressSource,
progressBlock: wrappedProgressSource.asProgressBlock(),
)
}
downloadResponse = try await downloadTask!.value

View File

@ -99,7 +99,7 @@ public class BaseOWSURLSessionMock: OWSURLSessionProtocol {
public func performUpload(
request: URLRequest,
requestData: Data,
progress: OWSProgressSource?,
progressBlock: OWSURLSession.ProgressBlock,
) async throws -> HTTPResponse {
// Want different behavior? Write a custom mock class
return HTTPResponse(
@ -114,7 +114,7 @@ public class BaseOWSURLSessionMock: OWSURLSessionProtocol {
request: URLRequest,
fileUrl: URL,
ignoreAppExpiry: Bool,
progress: OWSProgressSource?,
progressBlock: OWSURLSession.ProgressBlock,
) async throws -> HTTPResponse {
// Want different behavior? Write a custom mock class
return HTTPResponse(
@ -137,7 +137,7 @@ public class BaseOWSURLSessionMock: OWSURLSessionProtocol {
public func performDownload(
request: URLRequest,
progress: OWSProgressSource?,
progressBlock: OWSURLSession.ProgressBlock,
) async throws -> OWSUrlDownloadResponse {
// Want different behavior? Write a custom mock class
return OWSUrlDownloadResponse(
@ -149,7 +149,7 @@ public class BaseOWSURLSessionMock: OWSURLSessionProtocol {
public func performDownload(
requestUrl: URL,
resumeData: Data,
progress: OWSProgressSource?,
progressBlock: OWSURLSession.ProgressBlock,
) async throws -> OWSUrlDownloadResponse {
// Want different behavior? Write a custom mock class
return OWSUrlDownloadResponse(

View File

@ -127,14 +127,14 @@ public protocol OWSURLSessionProtocol: AnyObject {
func performUpload(
request: URLRequest,
requestData: Data,
progress: OWSProgressSource?,
progressBlock: OWSURLSession.ProgressBlock,
) async throws -> HTTPResponse
func performUpload(
request: URLRequest,
fileUrl: URL,
ignoreAppExpiry: Bool,
progress: OWSProgressSource?,
progressBlock: OWSURLSession.ProgressBlock,
) async throws -> HTTPResponse
func performRequest(
@ -145,12 +145,12 @@ public protocol OWSURLSessionProtocol: AnyObject {
func performDownload(
requestUrl: URL,
resumeData: Data,
progress: OWSProgressSource?,
progressBlock: OWSURLSession.ProgressBlock,
) async throws -> OWSUrlDownloadResponse
func performDownload(
request: URLRequest,
progress: OWSProgressSource?,
progressBlock: OWSURLSession.ProgressBlock,
) async throws -> OWSUrlDownloadResponse
func webSocketTask(
@ -188,10 +188,10 @@ public extension OWSURLSessionProtocol {
method: HTTPMethod,
headers: HttpHeaders = HttpHeaders(),
requestData: Data,
progress: OWSProgressSource? = nil,
progressBlock: OWSURLSession.ProgressBlock = { _, _ in },
) async throws -> HTTPResponse {
let request = try self.endpoint.buildRequest(urlString, method: method, headers: headers, body: requestData)
return try await self.performUpload(request: request, requestData: requestData, progress: progress)
return try await self.performUpload(request: request, requestData: requestData, progressBlock: progressBlock)
}
func performUpload(
@ -199,14 +199,14 @@ public extension OWSURLSessionProtocol {
method: HTTPMethod,
headers: HttpHeaders = HttpHeaders(),
fileUrl: URL,
progress: OWSProgressSource? = nil,
progressBlock: OWSURLSession.ProgressBlock = { _, _ in },
) async throws -> HTTPResponse {
let request = try self.endpoint.buildRequest(urlString, method: method, headers: headers)
return try await self.performUpload(
request: request,
fileUrl: fileUrl,
ignoreAppExpiry: false,
progress: progress,
progressBlock: progressBlock,
)
}
@ -230,10 +230,10 @@ public extension OWSURLSessionProtocol {
method: HTTPMethod,
headers: HttpHeaders = HttpHeaders(),
body: Data? = nil,
progress: OWSProgressSource? = nil,
progressBlock: OWSURLSession.ProgressBlock = { _, _ in },
) async throws -> OWSUrlDownloadResponse {
let request = try self.endpoint.buildRequest(urlString, method: method, headers: headers, body: body)
return try await self.performDownload(request: request, progress: progress)
return try await self.performDownload(request: request, progressBlock: progressBlock)
}
}
@ -249,7 +249,7 @@ extension OWSURLSessionProtocol {
mimeType: String,
textParts textPartsDictionary: OrderedDictionary<String, String>,
ignoreAppExpiry: Bool = false,
progress: OWSProgressSource? = nil,
progressBlock: OWSURLSession.ProgressBlock = { _, _ in },
) async throws -> HTTPResponse {
let multipartBodyFileURL = OWSFileSystem.temporaryFileUrl(
fileExtension: nil,
@ -289,7 +289,7 @@ extension OWSURLSessionProtocol {
request: request,
fileUrl: multipartBodyFileURL,
ignoreAppExpiry: ignoreAppExpiry,
progress: progress,
progressBlock: progressBlock,
)
}
}

View File

@ -11,6 +11,8 @@ public enum OWSURLSessionError: Error {
public class OWSURLSession: OWSURLSessionProtocol {
public typealias ProgressBlock = (_ completedByteCount: Int64, _ totalByteCount: Int64) async -> Void
// MARK: - OWSURLSessionProtocol conformance
public let endpoint: OWSURLSessionEndpoint
@ -149,12 +151,12 @@ public class OWSURLSession: OWSURLSessionProtocol {
public func performUpload(
request: URLRequest,
requestData: Data,
progress: OWSProgressSource?,
progressBlock: ProgressBlock,
) async throws -> HTTPResponse {
return try await performUpload(
request: request,
ignoreAppExpiry: false,
progress: progress,
progressBlock: progressBlock,
taskBlock: { self.session.uploadTask(with: request, from: requestData) },
)
}
@ -163,12 +165,12 @@ public class OWSURLSession: OWSURLSessionProtocol {
request: URLRequest,
fileUrl: URL,
ignoreAppExpiry: Bool,
progress: OWSProgressSource?,
progressBlock: ProgressBlock,
) async throws -> HTTPResponse {
return try await performUpload(
request: request,
ignoreAppExpiry: ignoreAppExpiry,
progress: progress,
progressBlock: progressBlock,
taskBlock: { self.session.uploadTask(with: request, fromFile: fileUrl) },
)
}
@ -182,9 +184,11 @@ public class OWSURLSession: OWSURLSessionProtocol {
let requestConfig = self.requestConfig(requestUrl: request.url!)
let task = session.dataTask(with: request)
let (urlResponse, responseData) = try await runTask(task, taskState: {
return DataTaskState(progressSource: nil, completion: $0)
})
let (urlResponse, responseData) = try await runTask(
task,
taskState: { DataTaskState(progress: $0, completion: $1) },
progressBlock: { _, _ in },
)
return try await handleDataResult(
urlResponse: urlResponse,
@ -196,13 +200,13 @@ public class OWSURLSession: OWSURLSessionProtocol {
public func performDownload(
request: URLRequest,
progress: OWSProgressSource?,
progressBlock: ProgressBlock,
) async throws -> OWSUrlDownloadResponse {
let request = prepareRequest(request: request)
guard let requestUrl = request.url else {
throw OWSAssertionError("Request missing url.")
}
return try await performDownload(requestUrl: requestUrl, progress: progress) {
return try await performDownload(requestUrl: requestUrl, progressBlock: progressBlock) {
// Don't use a completion block or the delegate will be ignored for download tasks.
return self.session.downloadTask(with: request)
}
@ -211,9 +215,9 @@ public class OWSURLSession: OWSURLSessionProtocol {
public func performDownload(
requestUrl: URL,
resumeData: Data,
progress: OWSProgressSource?,
progressBlock: ProgressBlock,
) async throws -> OWSUrlDownloadResponse {
return try await performDownload(requestUrl: requestUrl, progress: progress) {
return try await performDownload(requestUrl: requestUrl, progressBlock: progressBlock) {
// Don't use a completion block or the delegate will be ignored for download tasks.
return self.session.downloadTask(withResumeData: resumeData)
}
@ -452,7 +456,7 @@ public class OWSURLSession: OWSURLSessionProtocol {
do {
rawRequest.logger.info("Sending… -> \(rawRequest)")
let response = try await performUpload(request: request, requestData: requestBody, progress: nil)
let response = try await performUpload(request: request, requestData: requestBody, progressBlock: { _, _ in })
rawRequest.logger.info("HTTP \(response.responseStatusCode) <- \(rawRequest)")
return response
} catch where error.httpStatusCode != nil {
@ -467,7 +471,7 @@ public class OWSURLSession: OWSURLSessionProtocol {
private func performUpload(
request: URLRequest,
ignoreAppExpiry: Bool,
progress: OWSProgressSource?,
progressBlock: ProgressBlock,
taskBlock: () -> URLSessionUploadTask,
) async throws -> HTTPResponse {
if !ignoreAppExpiry, DependenciesBridge.shared.appExpiry.isExpired(now: Date()) {
@ -481,9 +485,11 @@ public class OWSURLSession: OWSURLSessionProtocol {
let urlResponse: URLResponse?
let responseData: Data
do {
(urlResponse, responseData) = try await runTask(task, taskState: {
return DataTaskState(progressSource: progress, completion: $0)
})
(urlResponse, responseData) = try await runTask(
task,
taskState: { DataTaskState(progress: $0, completion: $1) },
progressBlock: progressBlock,
)
} catch {
throw handleError(error, originalRequest: task.originalRequest, requestConfig: requestConfig)
}
@ -497,7 +503,7 @@ public class OWSURLSession: OWSURLSessionProtocol {
private func performDownload(
requestUrl: URL,
progress: OWSProgressSource?,
progressBlock: ProgressBlock,
taskBlock: () -> URLSessionDownloadTask,
) async throws -> OWSUrlDownloadResponse {
let appExpiry = DependenciesBridge.shared.appExpiry
@ -508,9 +514,11 @@ public class OWSURLSession: OWSURLSessionProtocol {
let requestConfig = self.requestConfig(requestUrl: requestUrl)
let task = taskBlock()
let (urlResponse, downloadUrl) = try await runTask(task, taskState: {
return DownloadTaskState(progressSource: progress, completion: $0)
})
let (urlResponse, downloadUrl) = try await runTask(
task,
taskState: { DownloadTaskState(progress: $0, completion: $1) },
progressBlock: progressBlock,
)
return try await handleDownloadResult(
urlResponse: urlResponse,
@ -520,7 +528,11 @@ public class OWSURLSession: OWSURLSessionProtocol {
)
}
private func runTask<T>(_ task: URLSessionTask, taskState: (CheckedContinuation<T, any Error>) -> some TaskState) async throws -> T {
private func runTask<T>(
_ task: URLSessionTask,
taskState: (TaskState.ProgressContinuation, DeferredContinuation<T>) -> some TaskState,
progressBlock: ProgressBlock,
) async throws -> T {
// It's possible for operation and onCancel to race one another, so we use
// a counter to ensure that cancellation happens after addTask is invoked.
// (You can trigger this by sending a request from a canceled Task.)
@ -528,8 +540,9 @@ public class OWSURLSession: OWSURLSessionProtocol {
return try await withTaskCancellationHandler(
operation: {
return try await withCheckedThrowingContinuation { continuation in
self.addTask(task, taskState: taskState(continuation))
let completion = DeferredContinuation<T>()
let progressStream = AsyncStream(bufferingPolicy: .bufferingNewest(1)) { continuation in
self.addTask(task, taskState: taskState(continuation, completion))
// If cancel was already called, cancel it now.
if cancelState.increment() == 2 {
task.cancel()
@ -537,6 +550,10 @@ public class OWSURLSession: OWSURLSessionProtocol {
task.resume()
}
}
for await progressUpdate in progressStream {
await progressBlock(progressUpdate.completedByteCount, progressUpdate.totalByteCount)
}
return try await completion.wait()
},
onCancel: {
// If the task was already added, cancel it now.
@ -569,9 +586,9 @@ public class OWSURLSession: OWSURLSessionProtocol {
}
}
private func progressSource(forTask task: URLSessionTask) -> OWSProgressSource? {
private func progress(forTask task: URLSessionTask) -> TaskState.ProgressContinuation? {
return updateTaskStates {
return $0[task.taskIdentifier]?.progressSource
return $0[task.taskIdentifier]?.progress
}
}
@ -606,7 +623,8 @@ public class OWSURLSession: OWSURLSessionProtocol {
owsFailDebug("Missing TaskState.")
return
}
taskState.completion.resume(returning: (task.response, downloadUrl))
taskState.progress?.finish()
taskState.completion.resume(with: .success((task.response, downloadUrl)))
}
private func dataTaskDidSucceed(_ task: URLSessionTask) {
@ -615,7 +633,8 @@ public class OWSURLSession: OWSURLSessionProtocol {
return
}
let responseData = taskState.pendingData.get()
taskState.completion.resume(returning: (task.response, responseData))
taskState.progress?.finish()
taskState.completion.resume(with: .success((task.response, responseData)))
}
private func taskDidFail(_ task: URLSessionTask, error: Error) {
@ -703,13 +722,8 @@ extension OWSURLSession {
}
func urlSession(_ session: URLSession, task: URLSessionTask, didSendBodyData bytesSent: Int64, totalBytesSent: Int64, totalBytesExpectedToSend: Int64) {
guard let progressSource = self.progressSource(forTask: task) else {
return
}
// TODO: We could check for NSURLSessionTransferSizeUnknown here.
if progressSource.completedUnitCount < totalBytesSent {
progressSource.incrementCompletedUnitCount(by: UInt64(totalBytesSent) - progressSource.completedUnitCount)
}
self.progress(forTask: task)?.yield((totalBytesSent, totalBytesExpectedToSend))
}
func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
@ -750,12 +764,7 @@ extension OWSURLSession {
taskDidFail(downloadTask, error: OWSURLSessionError.responseTooLarge)
return
}
guard let progressSource = self.progressSource(forTask: downloadTask) else {
return
}
if progressSource.completedUnitCount < totalBytesWritten {
progressSource.incrementCompletedUnitCount(by: UInt64(totalBytesWritten) - progressSource.completedUnitCount)
}
self.progress(forTask: downloadTask)?.yield((totalBytesWritten, totalBytesExpectedToWrite))
}
func urlSession(
@ -768,12 +777,7 @@ extension OWSURLSession {
taskDidFail(downloadTask, error: OWSURLSessionError.responseTooLarge)
return
}
guard let progressSource = self.progressSource(forTask: downloadTask) else {
return
}
if progressSource.completedUnitCount < fileOffset {
progressSource.incrementCompletedUnitCount(by: UInt64(fileOffset) - progressSource.completedUnitCount)
}
self.progress(forTask: downloadTask)?.yield((fileOffset, expectedTotalBytes))
}
func urlSession(
@ -811,25 +815,25 @@ extension OWSURLSession {
// MARK: - TaskState
private protocol TaskState {
typealias ProgressBlock = (URLSessionTask, Progress) -> Void
var progressSource: OWSProgressSource? { get }
typealias ProgressContinuation = AsyncStream<(completedByteCount: Int64, totalByteCount: Int64)>.Continuation
var progress: ProgressContinuation? { get }
func reject(error: any Error, task: URLSessionTask)
}
// MARK: - DownloadTaskState
private class DownloadTaskState: TaskState {
typealias CompletionContinuation = CheckedContinuation<(URLResponse?, URL), any Error>
let progressSource: OWSProgressSource?
let completion: CompletionContinuation
let progress: ProgressContinuation?
let completion: DeferredContinuation<(URLResponse?, URL)>
init(progressSource: OWSProgressSource?, completion: CompletionContinuation) {
self.progressSource = progressSource
init(progress: ProgressContinuation, completion: DeferredContinuation<(URLResponse?, URL)>) {
self.progress = progress
self.completion = completion
}
func reject(error: any Error, task: URLSessionTask) {
completion.resume(throwing: error)
self.progress?.finish()
self.completion.resume(with: .failure(error))
}
}
@ -837,19 +841,18 @@ private class DownloadTaskState: TaskState {
/// Also used for upload tasks, which are a subclass data tasks.
private class DataTaskState: TaskState {
typealias CompletionContinuation = CheckedContinuation<(URLResponse?, Data), any Error>
let pendingData = AtomicValue<Data>(Data(), lock: .init())
let progressSource: OWSProgressSource?
let completion: CompletionContinuation
let progress: ProgressContinuation?
let completion: DeferredContinuation<(URLResponse?, Data)>
init(progressSource: OWSProgressSource?, completion: CompletionContinuation) {
self.progressSource = progressSource
init(progress: ProgressContinuation?, completion: DeferredContinuation<(URLResponse?, Data)>) {
self.progress = progress
self.completion = completion
}
func reject(error: any Error, task: URLSessionTask) {
self.completion.resume(throwing: error)
self.progress?.finish()
self.completion.resume(with: .failure(error))
}
}
@ -859,7 +862,7 @@ private class WebSocketTaskState: TaskState {
typealias OpenBlock = (String?) -> Void
typealias CloseBlock = (Error) -> Void
var progressSource: OWSProgressSource? { nil }
var progress: ProgressContinuation? { nil }
let openBlock: OpenBlock
let closeBlock: CloseBlock

View File

@ -181,7 +181,7 @@ struct UploadEndpointCDN2: UploadEndpoint {
method: .put,
headers: headers,
requestData: uploadData,
progress: progress,
progressBlock: progress?.asProgressBlock() ?? { _, _ in },
)
switch response.responseStatusCode {
case 200, 201:

View File

@ -130,7 +130,7 @@ struct UploadEndpointCDN3: UploadEndpoint {
method: method,
headers: headers,
requestData: uploadData,
progress: progress,
progressBlock: progress?.asProgressBlock() ?? { _, _ in },
)
switch response.responseStatusCode {

View File

@ -339,6 +339,14 @@ extension OWSProgressSource {
extension OWSProgressSource where Self: Sendable {
func asProgressBlock() -> OWSURLSession.ProgressBlock {
return { completedByteCount, totalByteCount in
if self.completedUnitCount < completedByteCount {
self.incrementCompletedUnitCount(by: UInt64(completedByteCount) - self.completedUnitCount)
}
}
}
/// Given some block of asynchronous work, update progress
/// on the current source periodically (every ``timeInterval`` seconds)
/// until the work block completes.

View File

@ -8,7 +8,7 @@ import Foundation
typealias PerformTSRequestBlock = (TSRequest) async throws -> HTTPResponse
typealias PerformRequestBlock = (URLRequest) async throws -> HTTPResponse
typealias PerformUploadBlock = (URLRequest, Data, OWSProgressSource?) async throws -> HTTPResponse
typealias PerformUploadBlock = (URLRequest, Data, OWSURLSession.ProgressBlock) async throws -> HTTPResponse
enum MockRequestType {
case uploadForm(PerformTSRequestBlock)
@ -162,12 +162,12 @@ class AttachmentUploadManagerMockHelper {
}
}
mockURLSession.performUploadDataBlock = { request, data, progress in
mockURLSession.performUploadDataBlock = { request, data, progressBlock in
guard case let .uploadTask(requestBlock) = self.activeUploadRequestMocks.removeFirst() else {
throw OWSAssertionError("Mock request missing")
}
self.capturedRequests.append(.uploadTask(request))
return try await requestBlock(request, data, progress)
return try await requestBlock(request, data, progressBlock)
}
return insertMockAttachment(mockAttachment)
@ -342,10 +342,15 @@ class AttachmentUploadManagerMockHelper {
case networkTimeout
}
func addUploadRequestMock(auth: String, location: String, type: UploadResultType, completedCount: UInt64? = nil) {
enqueue(auth: auth, request: .uploadTask({ request, url, progress in
func addUploadRequestMock(
auth: String,
location: String,
type: UploadResultType,
completedCount: (completedByteCount: Int64, totalByteCount: Int64)? = nil,
) {
enqueue(auth: auth, request: .uploadTask({ request, url, progressBlock in
if let completedCount {
progress?.incrementCompletedUnitCount(by: completedCount)
await progressBlock(completedCount.completedByteCount, completedCount.totalByteCount)
}
switch type {
case .networkTimeout:

View File

@ -71,14 +71,14 @@ class _AttachmentUploadManager_NetworkManagerMock: NetworkManager {
public class _AttachmentUploadManager_OWSURLSessionMock: BaseOWSURLSessionMock {
public var performUploadDataBlock: ((URLRequest, Data, OWSProgressSource?) async throws -> HTTPResponse)?
override public func performUpload(request: URLRequest, requestData: Data, progress: OWSProgressSource?) async throws -> HTTPResponse {
return try await performUploadDataBlock!(request, requestData, progress)
public var performUploadDataBlock: ((URLRequest, Data, OWSURLSession.ProgressBlock) async throws -> HTTPResponse)?
override public func performUpload(request: URLRequest, requestData: Data, progressBlock: OWSURLSession.ProgressBlock) async throws -> HTTPResponse {
return try await performUploadDataBlock!(request, requestData, progressBlock)
}
public var performUploadFileBlock: ((URLRequest, URL, Bool, OWSProgressSource?) async throws -> HTTPResponse)?
override public func performUpload(request: URLRequest, fileUrl: URL, ignoreAppExpiry: Bool, progress: OWSProgressSource?) async throws -> HTTPResponse {
return try await performUploadFileBlock!(request, fileUrl, ignoreAppExpiry, progress)
public var performUploadFileBlock: ((URLRequest, URL, Bool, OWSURLSession.ProgressBlock) async throws -> HTTPResponse)?
override public func performUpload(request: URLRequest, fileUrl: URL, ignoreAppExpiry: Bool, progressBlock: OWSURLSession.ProgressBlock) async throws -> HTTPResponse {
return try await performUploadFileBlock!(request, fileUrl, ignoreAppExpiry, progressBlock)
}
public var performRequestBlock: ((URLRequest) async throws -> HTTPResponse)?

View File

@ -372,25 +372,25 @@ struct AttachmentUploadManagerTests {
// 1. Upload location request
let attempt = helper.addUploadFormAndLocationRequestMock(cdn: cdn) { auth, _, location in
// 2. Fail the upload with a server error
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: 20)
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: (20, 20))
// 3. Fetch the remote progress, but find none
helper.addResumeProgressMock(cdn: cdn, auth: auth, location: location, type: .missingRange)
// 4. Fail the upload with a server error
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: 20)
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: (20, 20))
// 5. Fetch the remote progress, but find none
helper.addResumeProgressMock(cdn: cdn, auth: auth, location: location, type: .missingRange)
// 6. Fail the upload with a server error
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: 20)
helper.addUploadRequestMock(auth: auth, location: location, type: .failure(code: 500), completedCount: (20, 20))
// 7. Fetch the remote progress, but find none
helper.addResumeProgressMock(cdn: cdn, auth: auth, location: location, type: .missingRange)
// 8. Succeed the upload
helper.addUploadRequestMock(auth: auth, location: location, type: .success(200), completedCount: 20)
helper.addUploadRequestMock(auth: auth, location: location, type: .success(200), completedCount: (20, 20))
}
try await uploadManager.uploadTransitTierAttachment(attachmentId: attachmentID, progress: nil)
@ -435,7 +435,7 @@ struct AttachmentUploadManagerTests {
helper.addResumeProgressMock(cdn: cdn, auth: auth, location: location, type: .progress(count: 15))
// 8. Succeed the upload
helper.addUploadRequestMock(auth: auth, location: location, type: .success(200), completedCount: 20)
helper.addUploadRequestMock(auth: auth, location: location, type: .success(200), completedCount: (20, 20))
}
try await uploadManager.uploadTransitTierAttachment(attachmentId: attachmentID, progress: nil)

View File

@ -439,7 +439,7 @@ private class MockDownloadSession: BaseOWSURLSessionMock {
var performDownloadSource: ((URL) async throws -> OWSUrlDownloadResponse)?
override func performDownload(request: URLRequest, progress: OWSProgressSource?) async throws -> OWSUrlDownloadResponse {
override func performDownload(request: URLRequest, progressBlock: OWSURLSession.ProgressBlock) async throws -> OWSUrlDownloadResponse {
return try await performDownloadSource!(request.url!)
}
}