diff --git a/Pods b/Pods index 771fe9d0e2..48f32a61cb 160000 --- a/Pods +++ b/Pods @@ -1 +1 @@ -Subproject commit 771fe9d0e2b0b45c3fe39edcd28c9cd93038e07b +Subproject commit 48f32a61cb1d3dea373cbff7308c234667a1d437 diff --git a/SignalServiceKit/src/Account/TSAccountManager.h b/SignalServiceKit/src/Account/TSAccountManager.h index 58ea68b794..a8eb9e1029 100644 --- a/SignalServiceKit/src/Account/TSAccountManager.h +++ b/SignalServiceKit/src/Account/TSAccountManager.h @@ -172,6 +172,7 @@ NSString *NSStringForOWSRegistrationState(OWSRegistrationState value); #ifdef TESTABLE_BUILD - (void)registerForTestsWithLocalNumber:(NSString *)localNumber uuid:(NSUUID *)uuid; +- (void)registerForTestsWithLocalNumber:(NSString *)localNumber uuid:(NSUUID *)uuid pni:(NSUUID *_Nullable)pni; #endif @end diff --git a/SignalServiceKit/src/Account/TSAccountManager.m b/SignalServiceKit/src/Account/TSAccountManager.m index 5434c57168..80c2d7d416 100644 --- a/SignalServiceKit/src/Account/TSAccountManager.m +++ b/SignalServiceKit/src/Account/TSAccountManager.m @@ -954,6 +954,11 @@ NSString *NSStringForOWSRegistrationState(OWSRegistrationState value) } - (void)registerForTestsWithLocalNumber:(NSString *)localNumber uuid:(NSUUID *)uuid +{ + [self registerForTestsWithLocalNumber:localNumber uuid:uuid pni:nil]; +} + +- (void)registerForTestsWithLocalNumber:(NSString *)localNumber uuid:(NSUUID *)uuid pni:(NSUUID *_Nullable)pni { OWSAssertDebug(SSKFeatureFlags.storageMode == StorageModeGrdbTests); OWSAssertDebug(CurrentAppContext().isRunningTests); @@ -961,7 +966,7 @@ NSString *NSStringForOWSRegistrationState(OWSRegistrationState value) OWSAssertDebug(uuid != nil); DatabaseStorageWrite(self.databaseStorage, ^(SDSAnyWriteTransaction *transaction) { - [self storeLocalNumber:localNumber aci:uuid pni:nil transaction:transaction]; + [self storeLocalNumber:localNumber aci:uuid pni:pni transaction:transaction]; }); } diff --git a/SignalServiceKit/src/Messages/MessageProcessor.swift b/SignalServiceKit/src/Messages/MessageProcessor.swift index 6d527d0f40..024405b0a0 100644 --- a/SignalServiceKit/src/Messages/MessageProcessor.swift +++ b/SignalServiceKit/src/Messages/MessageProcessor.swift @@ -681,6 +681,7 @@ public class PendingEnvelopes { public enum MessageProcessingError: Error { case wrongDestinationUuid + case invalidMessageTypeForDestinationUuid case duplicatePendingEnvelope case blockedSender } diff --git a/SignalServiceKit/src/Messages/OWSMessageDecrypter.swift b/SignalServiceKit/src/Messages/OWSMessageDecrypter.swift index b302e407f8..c97b9f9c7a 100644 --- a/SignalServiceKit/src/Messages/OWSMessageDecrypter.swift +++ b/SignalServiceKit/src/Messages/OWSMessageDecrypter.swift @@ -8,16 +8,19 @@ public struct OWSMessageDecryptResult: Dependencies { public let envelope: SSKProtoEnvelope public let envelopeData: Data? public let plaintextData: Data? + public let identity: OWSIdentity fileprivate init( envelope: SSKProtoEnvelope, envelopeData: Data?, plaintextData: Data?, + identity: OWSIdentity, transaction: SDSAnyWriteTransaction ) { self.envelope = envelope self.envelopeData = envelopeData self.plaintextData = plaintextData + self.identity = identity guard let sourceAddress = envelope.sourceAddress, sourceAddress.isValid else { owsFailDebug("missing source address") @@ -71,6 +74,26 @@ public class OWSMessageDecrypter: OWSMessageHandler { } } + private func localIdentity(forDestinationUuidString destinationUuidString: String?, + transaction: SDSAnyReadTransaction) throws -> OWSIdentity { + guard let destinationUuidString = destinationUuidString else { + return .aci + } + guard let destinationUuid = UUID(uuidString: destinationUuidString) else { + throw OWSAssertionError("incoming envelope has invalid destinationUuid: \(destinationUuidString)") + } + + switch destinationUuid { + case tsAccountManager.uuid(with: transaction): + return .aci + case tsAccountManager.pni(with: transaction): + return .pni + default: + // PNI TODO: Handle past PNIs? + throw MessageProcessingError.wrongDestinationUuid + } + } + public func decryptEnvelope(_ envelope: SSKProtoEnvelope, envelopeData: Data?, transaction: SDSAnyWriteTransaction) -> Result { @@ -100,26 +123,44 @@ public class OWSMessageDecrypter: OWSMessageHandler { } } + let identity: OWSIdentity + do { + identity = try localIdentity(forDestinationUuidString: envelope.destinationUuid, + transaction: transaction) + // Check expected envelope types. + switch (identity, envelope.unwrappedType) { + case (.aci, _): + break + case (.pni, .prekeyBundle), (.pni, .receipt): + break + default: + throw MessageProcessingError.invalidMessageTypeForDestinationUuid + } + } catch { + return .failure(error) + } + let plaintextDataOrError: Result switch envelope.unwrappedType { case .ciphertext: - plaintextDataOrError = decrypt(envelope, cipherType: .whisper, transaction: transaction) + plaintextDataOrError = decrypt(envelope, sentTo: identity, cipherType: .whisper, transaction: transaction) case .prekeyBundle: TSPreKeyManager.checkPreKeysIfNecessary() - plaintextDataOrError = decrypt(envelope, cipherType: .preKey, transaction: transaction) + plaintextDataOrError = decrypt(envelope, sentTo: identity, cipherType: .preKey, transaction: transaction) case .receipt, .keyExchange, .unknown: return .success(OWSMessageDecryptResult( envelope: envelope, envelopeData: envelopeData, plaintextData: nil, + identity: identity, transaction: transaction )) case .unidentifiedSender: - return decryptUnidentifiedSenderEnvelope(envelope, transaction: transaction) + return decryptUnidentifiedSenderEnvelope(envelope, sentTo: identity, transaction: transaction) case .senderkeyMessage: - plaintextDataOrError = decrypt(envelope, cipherType: .senderKey, transaction: transaction) + plaintextDataOrError = decrypt(envelope, sentTo: identity, cipherType: .senderKey, transaction: transaction) case .plaintextContent: - plaintextDataOrError = decrypt(envelope, cipherType: .plaintext, transaction: transaction) + plaintextDataOrError = decrypt(envelope, sentTo: identity, cipherType: .plaintext, transaction: transaction) @unknown default: Logger.warn("Received unhandled envelope type: \(envelope.unwrappedType)") return .failure(OWSGenericError("Received unhandled envelope type: \(envelope.unwrappedType)")) @@ -130,6 +171,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { envelope: envelope, envelopeData: envelopeData, plaintextData: $0, + identity: identity, transaction: transaction ) } @@ -264,6 +306,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { private func processError( _ error: Error, envelope: SSKProtoEnvelope, + sentTo identity: OWSIdentity, untrustedGroupId: Data?, cipherType: CiphertextMessage.MessageType, contentHint: SealedSenderContentHint, @@ -316,7 +359,8 @@ public class OWSMessageDecrypter: OWSMessageHandler { address: sourceAddress, transaction: transaction) let contentSupportsResend = envelopeContentSupportsResend(envelope: envelope, cipherType: cipherType, transaction: transaction) - let supportsModernResend = remoteUserSupportsSenderKey && localUserSupportsSenderKey && contentSupportsResend + let supportsModernResend = + (identity == .aci) && remoteUserSupportsSenderKey && localUserSupportsSenderKey && contentSupportsResend if supportsModernResend && !RemoteConfig.messageResendKillSwitch { Logger.info("Performing modern resend of \(contentHint) content with timestamp \(envelope.timestamp)") @@ -349,7 +393,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { cipherType: cipherType, failedEnvelopeGroupId: untrustedGroupId, transaction: transaction) - } else { + } else if identity == .aci { Logger.info("Performing legacy session reset of \(contentHint) content with timestamp \(envelope.timestamp)") let didReset = resetSessionIfNecessary( @@ -363,6 +407,12 @@ public class OWSMessageDecrypter: OWSMessageHandler { } else { errorMessage = nil } + } else { + Logger.info("Not resetting or requesting resend of message sent to \(identity)") + errorMessage = TSErrorMessage.failedDecryption( + for: envelope, + untrustedGroupId: untrustedGroupId, + with: transaction) } } else { owsFailDebug("Received envelope missing UUID \(sourceAddress).\(envelope.sourceDevice)") @@ -493,6 +543,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { } private func decrypt(_ envelope: SSKProtoEnvelope, + sentTo identity: OWSIdentity, cipherType: CiphertextMessage.MessageType, transaction: SDSAnyWriteTransaction) -> Result { @@ -519,8 +570,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { } let protocolAddress = try ProtocolAddress(from: sourceAddress, deviceId: deviceId) - // PNI TODO: make this dependent on destinationUuid - let signalProtocolStore = signalProtocolStore(for: .aci) + let signalProtocolStore = signalProtocolStore(for: identity) let plaintext: [UInt8] switch cipherType { @@ -529,7 +579,8 @@ public class OWSMessageDecrypter: OWSMessageHandler { plaintext = try signalDecrypt(message: message, from: protocolAddress, sessionStore: signalProtocolStore.sessionStore, - identityStore: identityManager.store(for: .aci, transaction: transaction), + identityStore: identityManager.store(for: identity, + transaction: transaction), context: transaction) sendReactiveProfileKeyIfNecessary(address: sourceAddress, transaction: transaction) case .preKey: @@ -537,7 +588,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { plaintext = try signalDecryptPreKey(message: message, from: protocolAddress, sessionStore: signalProtocolStore.sessionStore, - identityStore: identityManager.store(for: .aci, + identityStore: identityManager.store(for: identity, transaction: transaction), preKeyStore: signalProtocolStore.preKeyStore, signedPreKeyStore: signalProtocolStore.signedPreKeyStore, @@ -568,6 +619,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { let wrappedError = processError( error, envelope: envelope, + sentTo: identity, untrustedGroupId: nil, cipherType: cipherType, contentHint: .default, @@ -627,7 +679,11 @@ public class OWSMessageDecrypter: OWSMessageHandler { } } - private func decryptUnidentifiedSenderEnvelope(_ envelope: SSKProtoEnvelope, transaction: SDSAnyWriteTransaction) -> Result { + private func decryptUnidentifiedSenderEnvelope( + _ envelope: SSKProtoEnvelope, + sentTo identity: OWSIdentity, + transaction: SDSAnyWriteTransaction + ) -> Result { guard let encryptedData = envelope.content else { return .failure(OWSAssertionError("UD Envelope is missing content.")) } @@ -640,8 +696,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { return .failure(OWSAssertionError("Invalid serverTimestamp.")) } - // PNI TODO: make this dependent on destinationUuid - let signalProtocolStore = Self.signalProtocolStore(for: .aci) + let signalProtocolStore = Self.signalProtocolStore(for: identity) let cipher: SMKSecretSessionCipher do { @@ -649,7 +704,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { sessionStore: signalProtocolStore.sessionStore, preKeyStore: signalProtocolStore.preKeyStore, signedPreKeyStore: signalProtocolStore.signedPreKeyStore, - identityStore: identityManager.store(for: .aci, transaction: transaction), + identityStore: identityManager.store(for: identity, transaction: transaction), senderKeyStore: Self.senderKeyStore ) } catch { @@ -667,8 +722,9 @@ public class OWSMessageDecrypter: OWSMessageHandler { ) } catch let outerError as SecretSessionKnownSenderError { return .failure(handleUnidentifiedSenderDecryptionError( - error: outerError.underlyingError, + outerError.underlyingError, envelope: envelope.buildIdentifiedCopy(using: outerError), + sentTo: .aci, untrustedGroupId: outerError.groupId, cipherType: outerError.cipherType, contentHint: SealedSenderContentHint(outerError.contentHint), @@ -676,8 +732,9 @@ public class OWSMessageDecrypter: OWSMessageHandler { ) } catch { return .failure(handleUnidentifiedSenderDecryptionError( - error: error, + error, envelope: envelope, + sentTo: .aci, untrustedGroupId: nil, cipherType: .plaintext, contentHint: .default, @@ -721,13 +778,15 @@ public class OWSMessageDecrypter: OWSMessageHandler { envelope: identifiedEnvelope, envelopeData: nil, plaintextData: plaintextData, + identity: identity, transaction: transaction )) } func handleUnidentifiedSenderDecryptionError( - error: Error, + _ error: Error, envelope: SSKProtoEnvelope, + sentTo identity: OWSIdentity, untrustedGroupId: Data?, cipherType: CiphertextMessage.MessageType, contentHint: SealedSenderContentHint, @@ -739,6 +798,7 @@ public class OWSMessageDecrypter: OWSMessageHandler { } else if isSignalClientError(error) { return processError(error, envelope: envelope, + sentTo: identity, untrustedGroupId: untrustedGroupId, cipherType: cipherType, contentHint: contentHint, diff --git a/SignalServiceKit/src/TestUtils/NoopNotificationsManager.swift b/SignalServiceKit/src/TestUtils/NoopNotificationsManager.swift index 5e161f48a8..74f5141701 100644 --- a/SignalServiceKit/src/TestUtils/NoopNotificationsManager.swift +++ b/SignalServiceKit/src/TestUtils/NoopNotificationsManager.swift @@ -1,9 +1,10 @@ // -// Copyright (c) 2021 Open Whisper Systems. All rights reserved. +// Copyright (c) 2022 Open Whisper Systems. All rights reserved. // @objc public class NoopNotificationsManager: NSObject, NotificationsProtocol { + public var expectErrors: Bool = false public func notifyUser(forIncomingMessage incomingMessage: TSIncomingMessage, thread: TSThread, @@ -37,7 +38,7 @@ public class NoopNotificationsManager: NSObject, NotificationsProtocol { } public func notifyTestPopulation(ofErrorMessage errorString: String) { - owsFailDebug("Internal error message: \(errorString)") + owsAssertDebug(expectErrors, "Internal error message: \(errorString)") Logger.warn("Skipping internal error notification: \(errorString)") } diff --git a/SignalServiceKit/src/TestUtils/TestProtocolRunner.swift b/SignalServiceKit/src/TestUtils/TestProtocolRunner.swift index c9d54b59a9..a1291b9d61 100644 --- a/SignalServiceKit/src/TestUtils/TestProtocolRunner.swift +++ b/SignalServiceKit/src/TestUtils/TestProtocolRunner.swift @@ -12,8 +12,12 @@ public struct TestProtocolRunner { public init() { } - public func initialize(senderClient: TestSignalClient, recipientClient: TestSignalClient, transaction: SDSAnyWriteTransaction) throws { - + /// Sets up a session for `senderClient` to send to `recipientClient`, but not vice versa. + /// + /// Messages from `senderClient` will be PreKey messages. + public func initializePreKeys(senderClient: TestSignalClient, + recipientClient: TestSignalClient, + transaction: SDSAnyWriteTransaction) throws { _ = OWSAccountIdFinder.ensureAccountId(forAddress: senderClient.address, transaction: transaction) _ = OWSAccountIdFinder.ensureAccountId(forAddress: recipientClient.address, transaction: transaction) @@ -59,6 +63,16 @@ public struct TestProtocolRunner { ), id: signedPrekeyId, context: transaction) + } + + /// Sets up a session between `senderClient` and `recipientClient`, so that either can talk to the other. + /// + /// Messages between both clients will be "Whisper" / "ciphertext" / "Signal" messages. + public func initialize(senderClient: TestSignalClient, + recipientClient: TestSignalClient, + transaction: SDSAnyWriteTransaction) throws { + + try initializePreKeys(senderClient: senderClient, recipientClient: recipientClient, transaction: transaction) // Then Alice sends a message to Bob so he gets her pre-key as well. let aliceMessage = try encrypt(Data(), @@ -188,11 +202,14 @@ public struct FakeSignalClient: TestSignalClient { /// Represents the local user, backed by the same protocol stores, etc. /// used in the app. public struct LocalSignalClient: TestSignalClient { + public let identity: OWSIdentity - public init() { } + public init(identity: OWSIdentity = .aci) { + self.identity = identity + } public var identityKeyPair: ECKeyPair { - return SSKEnvironment.shared.identityManager.identityKeyPair(for: .aci)! + return SSKEnvironment.shared.identityManager.identityKeyPair(for: identity)! } public var e164Identifier: SignalE164Identifier? { @@ -200,26 +217,29 @@ public struct LocalSignalClient: TestSignalClient { } public var uuid: UUID { - return TSAccountManager.shared.localUuid! + switch identity { + case .aci: return TSAccountManager.shared.localUuid! + case .pni: return TSAccountManager.shared.localPni! + } } public let deviceId: UInt32 = 1 public var sessionStore: SessionStore { - return SSKEnvironment.shared.signalProtocolStore(for: .aci).sessionStore + return SSKEnvironment.shared.signalProtocolStore(for: identity).sessionStore } public var preKeyStore: PreKeyStore { - return SSKEnvironment.shared.signalProtocolStore(for: .aci).preKeyStore + return SSKEnvironment.shared.signalProtocolStore(for: identity).preKeyStore } public var signedPreKeyStore: SignedPreKeyStore { - return SSKEnvironment.shared.signalProtocolStore(for: .aci).signedPreKeyStore + return SSKEnvironment.shared.signalProtocolStore(for: identity).signedPreKeyStore } public var identityKeyStore: IdentityKeyStore { return SSKEnvironment.shared.databaseStorage.read { transaction in - return try! SSKEnvironment.shared.identityManager.store(for: .aci, transaction: transaction) + return try! SSKEnvironment.shared.identityManager.store(for: identity, transaction: transaction) } } } diff --git a/SignalServiceKit/tests/Messages/MessageDecryptionTest.swift b/SignalServiceKit/tests/Messages/MessageDecryptionTest.swift new file mode 100644 index 0000000000..50024fde94 --- /dev/null +++ b/SignalServiceKit/tests/Messages/MessageDecryptionTest.swift @@ -0,0 +1,206 @@ +// +// Copyright (c) 2022 Open Whisper Systems. All rights reserved. +// + +import XCTest +@testable import SignalServiceKit +import LibSignalClient + +class MessageDecryptionTest: SSKBaseTestSwift { + let localE164Identifier = "+13235551234" + let localAci = UUID() + let localPni = UUID() + + let remoteE164Identifier = "+14715355555" + lazy var remoteClient: TestSignalClient = FakeSignalClient.generate(e164Identifier: remoteE164Identifier) + + let localClient = LocalSignalClient() + let localPniClient = LocalSignalClient(identity: .pni) + let runner = TestProtocolRunner() + + let sealedSenderTrustRoot = Curve25519.generateKeyPair() + + // MARK: - Hooks + + override func setUp() { + super.setUp() + + // ensure local client has necessary "registered" state + identityManager.generateNewIdentityKey(for: .aci) + identityManager.generateNewIdentityKey(for: .pni) + tsAccountManager.registerForTests(withLocalNumber: localE164Identifier, uuid: localAci, pni: localPni) + + (notificationsManager as! NoopNotificationsManager).expectErrors = true + (udManager as! OWSUDManagerImpl).trustRoot = try! sealedSenderTrustRoot.ecPublicKey() + } + + // MARK: - Tests + + private let message = "abc" + + private func generateAndDecrypt(type: SSKProtoEnvelopeType, + destinationIdentity: OWSIdentity?, + destinationUuid: UUID? = nil, + handleResult: (Result, SSKProtoEnvelope) -> Void) { + write { transaction in + let localClient: TestSignalClient + if destinationIdentity == .pni { + localClient = self.localPniClient + } else { + localClient = self.localClient + } + + switch type { + case .ciphertext: + try! runner.initialize(senderClient: remoteClient, + recipientClient: localClient, + transaction: transaction) + case .prekeyBundle, .unidentifiedSender: + try! runner.initializePreKeys(senderClient: remoteClient, + recipientClient: localClient, + transaction: transaction) + default: + XCTFail("unsupported envelope type for this test: \(type)") + return + } + + let ciphertext = try! runner.encrypt(message.data(using: .utf8)!, + senderClient: remoteClient, + recipient: localClient.protocolAddress, + context: transaction) + + let envelopeBuilder = SSKProtoEnvelope.builder(timestamp: 0) + envelopeBuilder.setType(type) + if let destinationUuid = destinationUuid { + envelopeBuilder.setDestinationUuid(destinationUuid.uuidString) + } else if destinationIdentity != nil { + envelopeBuilder.setDestinationUuid(localClient.uuidIdentifier) + } + + if type == .unidentifiedSender { + let senderCert = SMKSecretSessionCipherTest.createCertificateFor( + trustRoot: sealedSenderTrustRoot.identityKeyPair, + senderAddress: remoteClient.address, + senderDeviceId: remoteClient.deviceId, + identityKey: remoteClient.identityKeyPair.identityKeyPair.publicKey, + expirationTimestamp: 13337) + let usmc = try! UnidentifiedSenderMessageContent(ciphertext, + from: senderCert, + contentHint: .default, + groupId: []) + envelopeBuilder.setContent(Data(try! sealedSenderEncrypt(usmc, + for: localClient.protocolAddress, + identityStore: remoteClient.identityKeyStore, + context: transaction))) + envelopeBuilder.setServerTimestamp(13336) + } else { + envelopeBuilder.setSourceUuid(remoteClient.uuidIdentifier) + envelopeBuilder.setSourceDevice(remoteClient.deviceId) + envelopeBuilder.setContent(Data(ciphertext.serialize())) + } + + let envelope = try! envelopeBuilder.build() + handleResult(messageDecrypter.decryptEnvelope(envelope, envelopeData: nil, transaction: transaction), + envelope) + } + } + + private func expectDecryptsSuccessfully(type: SSKProtoEnvelopeType, destinationIdentity: OWSIdentity?) { + generateAndDecrypt(type: type, destinationIdentity: destinationIdentity) { result, originalEnvelope in + let decrypted = try! result.get() + XCTAssertNil(decrypted.envelopeData) + XCTAssertEqual(decrypted.identity, destinationIdentity ?? .aci) + XCTAssertNotNil(decrypted.plaintextData) + XCTAssertEqual(String(data: decrypted.plaintextData!, encoding: .utf8), message) + + if type == .unidentifiedSender { + XCTAssertNotIdentical(decrypted.envelope, originalEnvelope) + } else { + XCTAssertIdentical(decrypted.envelope, originalEnvelope) + } + } + } + + private func expectDecryptionFailure(type: SSKProtoEnvelopeType, + destinationIdentity: OWSIdentity?, + destinationUuid: UUID? = nil, + isExpectedError: (Error) -> Bool) { + generateAndDecrypt(type: type, + destinationIdentity: destinationIdentity, + destinationUuid: destinationUuid) { result, _ in + switch result { + case .success: + XCTFail("should not have decrypted successfully") + case .failure(let error): + XCTAssert(isExpectedError(error), "unexpected error: \(error)") + } + } + } + + func testDecryptWhisper() { + expectDecryptsSuccessfully(type: .ciphertext, destinationIdentity: nil) + } + + func testDecryptWhisperExplicitAci() { + expectDecryptsSuccessfully(type: .ciphertext, destinationIdentity: .aci) + } + + func testDecryptWhisperPni() { + expectDecryptionFailure(type: .ciphertext, destinationIdentity: .pni) { error in + if case MessageProcessingError.invalidMessageTypeForDestinationUuid = error { + return true + } + return false + } + } + + func testDecryptPreKey() { + expectDecryptsSuccessfully(type: .prekeyBundle, destinationIdentity: nil) + } + + func testDecryptPreKeyExplicitAci() { + expectDecryptsSuccessfully(type: .prekeyBundle, destinationIdentity: .aci) + } + + func testDecryptPreKeyPni() { + expectDecryptsSuccessfully(type: .prekeyBundle, destinationIdentity: .pni) + } + + func testDecryptPreKeyPniWithAciDestinationUuid() { + expectDecryptionFailure(type: .prekeyBundle, + destinationIdentity: .pni, + destinationUuid: localClient.uuid) { error in + if let error = error as? OWSError { + let underlyingError = error.errorUserInfo[NSUnderlyingErrorKey] + if case SSKSignedPreKeyStore.Error.noPreKeyWithId(_)? = underlyingError { + return true + } + } + return false + } + } + + func testDecryptPreKeyPniWithWrongDestinationUuid() { + expectDecryptionFailure(type: .prekeyBundle, + destinationIdentity: .pni, + destinationUuid: UUID()) { error in + if case MessageProcessingError.wrongDestinationUuid = error { + return true + } + return false + } + } + + func testDecryptSealedSenderPreKey() { + expectDecryptsSuccessfully(type: .unidentifiedSender, destinationIdentity: nil) + } + + func testDecryptSealedSenderPreKeyPni() { + expectDecryptionFailure(type: .unidentifiedSender, destinationIdentity: .pni) { error in + if case MessageProcessingError.invalidMessageTypeForDestinationUuid = error { + return true + } + return false + } + } +} diff --git a/SignalServiceKit/tests/Messages/SMKSecretSessionCipherTest.swift b/SignalServiceKit/tests/Messages/SMKSecretSessionCipherTest.swift index e5e1c37403..2b5691e3fa 100644 --- a/SignalServiceKit/tests/Messages/SMKSecretSessionCipherTest.swift +++ b/SignalServiceKit/tests/Messages/SMKSecretSessionCipherTest.swift @@ -27,11 +27,11 @@ class SMKSecretSessionCipherTest: SSKBaseTestSwift { let trustRoot = IdentityKeyPair.generate() // SenderCertificate senderCertificate = createCertificateFor(trustRoot, "+14151111111", 1, aliceStore.getIdentityKeyPair().getPublicKey().getPublicKey(), 31337); - let senderCertificate = createCertificateFor(trustRoot: trustRoot, - senderAddress: aliceMockClient.address, - senderDeviceId: UInt32(aliceMockClient.deviceId), - identityKey: aliceMockClient.identityKeyPair.publicKey, - expirationTimestamp: 31337) + let senderCertificate = Self.createCertificateFor(trustRoot: trustRoot, + senderAddress: aliceMockClient.address, + senderDeviceId: UInt32(aliceMockClient.deviceId), + identityKey: aliceMockClient.identityKeyPair.publicKey, + expirationTimestamp: 31337) // SecretSessionCipher aliceCipher = new SecretSessionCipher(aliceStore); let aliceCipher: SMKSecretSessionCipher = try! aliceMockClient.createSecretSessionCipher() @@ -78,11 +78,11 @@ class SMKSecretSessionCipherTest: SSKBaseTestSwift { let trustRoot = IdentityKeyPair.generate() let falseTrustRoot = IdentityKeyPair.generate() // SenderCertificate senderCertificate = createCertificateFor(falseTrustRoot, "+14151111111", 1, aliceStore.getIdentityKeyPair().getPublicKey().getPublicKey(), 31337); - let senderCertificate = createCertificateFor(trustRoot: falseTrustRoot, - senderAddress: aliceMockClient.address, - senderDeviceId: UInt32(aliceMockClient.deviceId), - identityKey: aliceMockClient.identityKeyPair.publicKey, - expirationTimestamp: 31337) + let senderCertificate = Self.createCertificateFor(trustRoot: falseTrustRoot, + senderAddress: aliceMockClient.address, + senderDeviceId: UInt32(aliceMockClient.deviceId), + identityKey: aliceMockClient.identityKeyPair.publicKey, + expirationTimestamp: 31337) // SecretSessionCipher aliceCipher = new SecretSessionCipher(aliceStore); let aliceCipher: SMKSecretSessionCipher = try! aliceMockClient.createSecretSessionCipher() @@ -151,11 +151,11 @@ class SMKSecretSessionCipherTest: SSKBaseTestSwift { let trustRoot = IdentityKeyPair.generate() // SenderCertificate senderCertificate = createCertificateFor(trustRoot, "+14151111111", 1, aliceStore.getIdentityKeyPair().getPublicKey().getPublicKey(), 31337); - let senderCertificate = createCertificateFor(trustRoot: trustRoot, - senderAddress: aliceMockClient.address, - senderDeviceId: UInt32(aliceMockClient.deviceId), - identityKey: aliceMockClient.identityKeyPair.publicKey, - expirationTimestamp: 31337) + let senderCertificate = Self.createCertificateFor(trustRoot: trustRoot, + senderAddress: aliceMockClient.address, + senderDeviceId: UInt32(aliceMockClient.deviceId), + identityKey: aliceMockClient.identityKeyPair.publicKey, + expirationTimestamp: 31337) // SecretSessionCipher aliceCipher = new SecretSessionCipher(aliceStore); let aliceCipher: SMKSecretSessionCipher = try! aliceMockClient.createSecretSessionCipher() @@ -227,11 +227,11 @@ class SMKSecretSessionCipherTest: SSKBaseTestSwift { // ECKeyPair randomKeyPair = Curve.generateKeyPair(); let randomKeyPair = IdentityKeyPair.generate() // SenderCertificate senderCertificate = createCertificateFor(trustRoot, "+14151111111", 1, randomKeyPair.getPublicKey(), 31337); - let senderCertificate = createCertificateFor(trustRoot: trustRoot, - senderAddress: aliceMockClient.address, - senderDeviceId: UInt32(aliceMockClient.deviceId), - identityKey: randomKeyPair.publicKey, - expirationTimestamp: 31337) + let senderCertificate = Self.createCertificateFor(trustRoot: trustRoot, + senderAddress: aliceMockClient.address, + senderDeviceId: UInt32(aliceMockClient.deviceId), + identityKey: randomKeyPair.publicKey, + expirationTimestamp: 31337) // SecretSessionCipher aliceCipher = new SecretSessionCipher(aliceStore); let aliceCipher: SMKSecretSessionCipher = try! aliceMockClient.createSecretSessionCipher() @@ -275,7 +275,7 @@ class SMKSecretSessionCipherTest: SSKBaseTestSwift { initializeSessions(aliceMockClient: aliceMockClient, bobMockClient: bobMockClient) let trustRoot = IdentityKeyPair.generate() - let senderCertificate = createCertificateFor( + let senderCertificate = Self.createCertificateFor( trustRoot: trustRoot, senderAddress: aliceMockClient.address, senderDeviceId: UInt32(aliceMockClient.deviceId), @@ -334,7 +334,7 @@ class SMKSecretSessionCipherTest: SSKBaseTestSwift { initializeSessions(aliceMockClient: aliceMockClient, bobMockClient: bobMockClient) let trustRoot = IdentityKeyPair.generate() - let senderCertificate = createCertificateFor( + let senderCertificate = Self.createCertificateFor( trustRoot: trustRoot, senderAddress: aliceMockClient.address, senderDeviceId: UInt32(aliceMockClient.deviceId), @@ -406,11 +406,11 @@ class SMKSecretSessionCipherTest: SSKBaseTestSwift { // private SenderCertificate createCertificateFor(ECKeyPair trustRoot, String sender, int deviceId, ECPublicKey identityKey, long expires) // throws InvalidKeyException, InvalidCertificateException, InvalidProtocolBufferException { - private func createCertificateFor(trustRoot: IdentityKeyPair, - senderAddress: SignalServiceAddress, - senderDeviceId: UInt32, - identityKey: PublicKey, - expirationTimestamp: UInt64) -> SenderCertificate { + internal static func createCertificateFor(trustRoot: IdentityKeyPair, + senderAddress: SignalServiceAddress, + senderDeviceId: UInt32, + identityKey: PublicKey, + expirationTimestamp: UInt64) -> SenderCertificate { let serverKey = IdentityKeyPair.generate() let serverCertificate = try! ServerCertificate(keyId: 1, publicKey: serverKey.publicKey, diff --git a/SignalServiceKit/tests/SSKBaseTestSwift.swift b/SignalServiceKit/tests/SSKBaseTestSwift.swift index 24000b4b1e..7d4835a3a0 100644 --- a/SignalServiceKit/tests/SSKBaseTestSwift.swift +++ b/SignalServiceKit/tests/SSKBaseTestSwift.swift @@ -1,5 +1,5 @@ // -// Copyright (c) 2021 Open Whisper Systems. All rights reserved. +// Copyright (c) 2022 Open Whisper Systems. All rights reserved. // import XCTest @@ -38,11 +38,11 @@ public class SSKBaseTestSwift: XCTestCase { } @objc - public func read(_ block: @escaping (SDSAnyReadTransaction) -> Void) { + public func read(_ block: (SDSAnyReadTransaction) -> Void) { return databaseStorage.read(block: block) } - public func write(_ block: @escaping (SDSAnyWriteTransaction) -> T) -> T { + public func write(_ block: (SDSAnyWriteTransaction) -> T) -> T { return databaseStorage.write(block: block) }