Add unit tests for PhoneVerificationTokenManager

This commit is contained in:
Jon Chambers 2026-04-14 12:01:15 -04:00 committed by Jon Chambers
parent 9e6cbe8f82
commit c02667e2e4
3 changed files with 303 additions and 346 deletions

View File

@ -0,0 +1,240 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.grpc.Status;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.NotAuthorizedException;
import jakarta.ws.rs.ServerErrorException;
import jakarta.ws.rs.container.ContainerRequestContext;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.spam.RegistrationRecoveryChecker;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
class PhoneVerificationTokenManagerTest {
private RegistrationServiceClient registrationServiceClient;
private RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private RegistrationRecoveryChecker registrationRecoveryChecker;
private PhoneVerificationTokenManager phoneVerificationTokenManager;
private static final String PHONE_NUMBER = PhoneNumberUtil.getInstance()
.format(PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164);
private static final UUID PHONE_NUMBER_IDENTIFIER = UUID.randomUUID();
record PhoneVerificationRequest(String sessionId, byte[] recoveryPassword) implements org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest {
static PhoneVerificationRequest forSessionId(final byte[] sessionId) {
return new PhoneVerificationRequest(Base64.getUrlEncoder().encodeToString(sessionId), null);
}
static PhoneVerificationRequest forRecoveryPassword(final byte[] recoveryPassword) {
return new PhoneVerificationRequest(null, recoveryPassword);
}
}
@BeforeEach
void setUp() {
final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class);
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(PHONE_NUMBER))
.thenReturn(CompletableFuture.completedFuture(PHONE_NUMBER_IDENTIFIER));
registrationServiceClient = mock(RegistrationServiceClient.class);
registrationRecoveryPasswordsManager = mock(RegistrationRecoveryPasswordsManager.class);
registrationRecoveryChecker = mock(RegistrationRecoveryChecker.class);
phoneVerificationTokenManager = new PhoneVerificationTokenManager(phoneNumberIdentifiers,
registrationServiceClient,
registrationRecoveryPasswordsManager,
registrationRecoveryChecker);
}
@Nested
class SessionBasedVerification {
@Test
void verify() {
final byte[] sessionId = TestRandomUtil.nextBytes(16);
when(registrationServiceClient.getSession(eq(sessionId), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(
new RegistrationServiceSession(sessionId, PHONE_NUMBER, true, null, null, null, 0))));
assertDoesNotThrow(() -> phoneVerificationTokenManager.verify(mock(ContainerRequestContext.class),
PHONE_NUMBER,
PhoneVerificationRequest.forSessionId(sessionId)));
}
@Test
void verifySessionNotFound() {
final byte[] sessionId = TestRandomUtil.nextBytes(16);
when(registrationServiceClient.getSession(eq(sessionId), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
assertThrows(NotAuthorizedException.class, () -> phoneVerificationTokenManager.verify(mock(ContainerRequestContext.class),
PHONE_NUMBER,
PhoneVerificationRequest.forSessionId(sessionId)));
}
@Test
void verifyNumberMismatch() {
final byte[] sessionId = TestRandomUtil.nextBytes(16);
when(registrationServiceClient.getSession(eq(sessionId), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(
new RegistrationServiceSession(sessionId, PHONE_NUMBER + "0", true, null, null, null, 0))));
assertThrows(BadRequestException.class, () -> phoneVerificationTokenManager.verify(mock(ContainerRequestContext.class),
PHONE_NUMBER,
PhoneVerificationRequest.forSessionId(sessionId)));
}
@Test
void verifySessionNotVerified() {
final byte[] sessionId = TestRandomUtil.nextBytes(16);
when(registrationServiceClient.getSession(eq(sessionId), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(
new RegistrationServiceSession(sessionId, PHONE_NUMBER, false, null, null, null, 0))));
assertThrows(NotAuthorizedException.class, () -> phoneVerificationTokenManager.verify(mock(ContainerRequestContext.class),
PHONE_NUMBER,
PhoneVerificationRequest.forSessionId(sessionId)));
}
@ParameterizedTest
@MethodSource
void verifyRegistrationServiceClientException(final Throwable registrationServiceClientException,
final Class<Throwable> expectedExceptionClass) throws ExecutionException, InterruptedException, TimeoutException {
@SuppressWarnings("unchecked") final CompletableFuture<Optional<RegistrationServiceSession>> mockFuture
= mock(CompletableFuture.class);
when(mockFuture.get(anyLong(), any())).thenThrow(registrationServiceClientException);
when(registrationServiceClient.getSession(any(), any())).thenReturn(mockFuture);
assertThrows(expectedExceptionClass, () -> phoneVerificationTokenManager.verify(mock(ContainerRequestContext.class),
PHONE_NUMBER,
PhoneVerificationRequest.forSessionId(TestRandomUtil.nextBytes(16))));
}
private static List<Arguments> verifyRegistrationServiceClientException() {
return List.of(
Arguments.arguments(new ExecutionException(Status.INVALID_ARGUMENT.asRuntimeException()), BadRequestException.class),
Arguments.arguments(new ExecutionException(Status.RESOURCE_EXHAUSTED.asRuntimeException()), ServerErrorException.class),
Arguments.arguments(new CancellationException(), ServerErrorException.class),
Arguments.arguments(new TimeoutException(), ServerErrorException.class)
);
}
}
@Nested
class PasswordBasedVerification {
@Test
void verify() {
final ContainerRequestContext containerRequestContext = mock(ContainerRequestContext.class);
final byte[] recoveryPassword = TestRandomUtil.nextBytes(16);
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(containerRequestContext, PHONE_NUMBER))
.thenReturn(true);
when(registrationRecoveryPasswordsManager.verify(PHONE_NUMBER_IDENTIFIER, recoveryPassword))
.thenReturn(CompletableFuture.completedFuture(true));
assertDoesNotThrow(() -> phoneVerificationTokenManager.verify(containerRequestContext,
PHONE_NUMBER,
PhoneVerificationRequest.forRecoveryPassword(recoveryPassword)));
}
@Test
void verifyRecoveryCheckerDeclined() {
final ContainerRequestContext containerRequestContext = mock(ContainerRequestContext.class);
final byte[] recoveryPassword = TestRandomUtil.nextBytes(16);
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(containerRequestContext, PHONE_NUMBER))
.thenReturn(false);
when(registrationRecoveryPasswordsManager.verify(PHONE_NUMBER_IDENTIFIER, recoveryPassword))
.thenReturn(CompletableFuture.completedFuture(true));
assertThrows(ForbiddenException.class, () -> phoneVerificationTokenManager.verify(containerRequestContext,
PHONE_NUMBER,
PhoneVerificationRequest.forRecoveryPassword(recoveryPassword)));
}
@Test
void verifyRecoveryPasswordNotVerified() {
final ContainerRequestContext containerRequestContext = mock(ContainerRequestContext.class);
final byte[] recoveryPassword = TestRandomUtil.nextBytes(16);
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(containerRequestContext, PHONE_NUMBER))
.thenReturn(true);
when(registrationRecoveryPasswordsManager.verify(PHONE_NUMBER_IDENTIFIER, recoveryPassword))
.thenReturn(CompletableFuture.completedFuture(false));
assertThrows(ForbiddenException.class, () -> phoneVerificationTokenManager.verify(containerRequestContext,
PHONE_NUMBER,
PhoneVerificationRequest.forRecoveryPassword(recoveryPassword)));
}
@ParameterizedTest
@MethodSource
void verifyRecoveryPasswordManagerException(final Throwable recoveryPasswordManagerException)
throws ExecutionException, InterruptedException, TimeoutException {
final ContainerRequestContext containerRequestContext = mock(ContainerRequestContext.class);
final byte[] recoveryPassword = TestRandomUtil.nextBytes(16);
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(containerRequestContext, PHONE_NUMBER))
.thenReturn(true);
@SuppressWarnings("unchecked") final CompletableFuture<Boolean> mockFuture = mock(CompletableFuture.class);
when(mockFuture.get(anyLong(), any())).thenThrow(recoveryPasswordManagerException);
when(registrationRecoveryPasswordsManager.verify(PHONE_NUMBER_IDENTIFIER, recoveryPassword))
.thenReturn(mockFuture);
assertThrows(ServerErrorException.class, () -> phoneVerificationTokenManager.verify(containerRequestContext,
PHONE_NUMBER,
PhoneVerificationRequest.forRecoveryPassword(recoveryPassword)));
}
private static List<Throwable> verifyRecoveryPasswordManagerException() {
return List.of(new ExecutionException(new RuntimeException()), new TimeoutException());
}
}
}

View File

@ -25,6 +25,10 @@ import com.google.i18n.phonenumbers.Phonenumber;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.NotAuthorizedException;
import jakarta.ws.rs.ServerErrorException;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.client.Invocation;
@ -44,7 +48,6 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
@ -74,7 +77,7 @@ import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -82,15 +85,11 @@ import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMa
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.spam.RegistrationRecoveryChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@ -100,8 +99,6 @@ import org.whispersystems.textsecuregcm.util.Util;
@ExtendWith(DropwizardExtensionsSupport.class)
class AccountControllerV2Test {
private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds();
private static final ECKeyPair IDENTITY_KEY_PAIR = ECKeyPair.generate();
private static final IdentityKey IDENTITY_KEY = new IdentityKey(IDENTITY_KEY_PAIR.getPublicKey());
@ -111,15 +108,11 @@ class AccountControllerV2Test {
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class);
private final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class);
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final PhoneVerificationTokenManager phoneVerificationTokenManager = mock(PhoneVerificationTokenManager.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
RegistrationLockVerificationManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
private final RegistrationRecoveryChecker registrationRecoveryChecker = mock(RegistrationRecoveryChecker.class);
private final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
@ -131,9 +124,7 @@ class AccountControllerV2Test {
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
new AccountControllerV2(accountsManager, changeNumberManager,
new PhoneVerificationTokenManager(phoneNumberIdentifiers, registrationServiceClient,
registrationRecoveryPasswordsManager, registrationRecoveryChecker),
new AccountControllerV2(accountsManager, changeNumberManager, phoneVerificationTokenManager,
registrationLockVerificationManager, rateLimiters))
.build();
@ -142,6 +133,8 @@ class AccountControllerV2Test {
@BeforeEach
void setUp() throws Exception {
reset(changeNumberManager);
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
@ -173,16 +166,16 @@ class AccountControllerV2Test {
return updatedAccount;
});
when(phoneVerificationTokenManager.verify(any(), any(), any())).thenAnswer(invocation -> {
final PhoneVerificationRequest request = invocation.getArgument(2);
return request.verificationType();
});
}
@Test
void changeNumberSuccess() throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NEW_NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final AccountIdentityResponse accountIdentityResponse =
resources.getJerseyTest()
.target("/v2/accounts/number")
@ -295,10 +288,6 @@ class AccountControllerV2Test {
@ParameterizedTest
@MethodSource
void invalidRegistrationId(final Integer pniRegistrationId, final int expectedStatusCode) {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NEW_NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final ChangeNumberRequest changeNumberRequest = new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", IDENTITY_KEY,
Collections.emptyList(),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)),
@ -341,61 +330,35 @@ class AccountControllerV2Test {
}
}
@Test
void registrationServiceTimeout() {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
}
}
@ParameterizedTest
@MethodSource
void registrationServiceSessionCheck(@Nullable final RegistrationServiceSession session, final int expectedStatus,
final String message) {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(session)));
void phoneVerificationException(final Exception exception, final int expectedStatus) throws InterruptedException {
doThrow(exception)
.when(phoneVerificationTokenManager).verify(any(), any(), any());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
assertEquals(expectedStatus, response.getStatus(), message);
assertEquals(expectedStatus, response.getStatus());
}
}
static Stream<Arguments> registrationServiceSessionCheck() {
return Stream.of(
Arguments.of(null, 401, "session not found"),
Arguments.of(new RegistrationServiceSession(new byte[16], "+18005551234", false, null, null, null,
SESSION_EXPIRATION_SECONDS), 400,
"session number mismatch"),
Arguments.of(
new RegistrationServiceSession(new byte[16], NEW_NUMBER, false, null, null, null,
SESSION_EXPIRATION_SECONDS),
401,
"session not verified")
private static List<Arguments> phoneVerificationException() {
return List.of(
Arguments.argumentSet("Bad request", new BadRequestException(), HttpStatus.SC_BAD_REQUEST),
Arguments.argumentSet("Not authorized", new NotAuthorizedException("test"), HttpStatus.SC_UNAUTHORIZED),
Arguments.argumentSet("Forbidden", new ForbiddenException(), HttpStatus.SC_FORBIDDEN),
Arguments.argumentSet("Unexpected exception", new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE), HttpStatus.SC_SERVICE_UNAVAILABLE)
);
}
@ParameterizedTest
@EnumSource(RegistrationLockError.class)
void registrationLock(final RegistrationLockError error) throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NEW_NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
when(accountsManager.getByE164(any())).thenReturn(Optional.of(mock(Account.class)));
final Exception e = switch (error) {
@ -417,20 +380,15 @@ class AccountControllerV2Test {
@Test
void recoveryPasswordManagerVerificationTrue() throws Exception {
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(any()))
.thenReturn(CompletableFuture.completedFuture(UUID.randomUUID()));
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(any(), any()))
.thenReturn(true);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(recoveryPassword, NEW_NUMBER)))) {
try (final Response response = request.put(Entity.json(requestJsonRecoveryPassword(recoveryPassword, NEW_NUMBER)))) {
assertEquals(200, response.getStatus());
final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class);
@ -444,68 +402,10 @@ class AccountControllerV2Test {
}
}
@Test
void recoveryPasswordManagerVerificationFalse() {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(false));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(new byte[32], NEW_NUMBER)))) {
assertEquals(403, response.getStatus());
}
}
@Test
void registrationRecoveryCheckerAllowsAttempt() {
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(any()))
.thenReturn(CompletableFuture.completedFuture(UUID.randomUUID()));
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(any(), any())).thenReturn(true);
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(recoveryPassword, NEW_NUMBER)))) {
assertEquals(200, response.getStatus());
}
}
@Test
void registrationRecoveryCheckerDisallowsAttempt() {
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(any(), any())).thenReturn(false);
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(recoveryPassword, NEW_NUMBER)))) {
assertEquals(403, response.getStatus());
}
}
@Test
void deviceMessageTooLarge() throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NEW_NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
reset(changeNumberManager);
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any(), any()))
.thenThrow(MessageTooLargeException.class);
doThrow(MessageTooLargeException.class)
.when(changeNumberManager).changeNumber(any(), any(), any(), any(), any(), any(), any(), any());
try (final Response response = resources.getJerseyTest()
.target("/v2/accounts/number")

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
@ -19,6 +20,10 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.NotAuthorizedException;
import jakarta.ws.rs.ServerErrorException;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.client.Invocation;
@ -26,7 +31,6 @@ import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.Response;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
@ -37,7 +41,6 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.annotation.Nullable;
@ -68,23 +71,19 @@ import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.spam.RegistrationFraudChecker;
import org.whispersystems.textsecuregcm.spam.RegistrationRecoveryChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
@ -94,9 +93,6 @@ import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ExtendWith(DropwizardExtensionsSupport.class)
class RegistrationControllerTest {
private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds();
private static final String NUMBER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
@ -104,15 +100,11 @@ class RegistrationControllerTest {
private static final String REGLOCK = RandomStringUtils.insecure().nextAlphanumeric(64);
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class);
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
RegistrationLockVerificationManager.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final RegistrationRecoveryChecker registrationRecoveryChecker = mock(RegistrationRecoveryChecker.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager
= mock(RegistrationLockVerificationManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RegistrationFraudChecker registrationFraudChecker = mock(RegistrationFraudChecker.class);
private final PhoneVerificationTokenManager phoneVerificationTokenManager = mock(PhoneVerificationTokenManager.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
@ -123,15 +115,12 @@ class RegistrationControllerTest {
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
new RegistrationController(accountsManager,
new PhoneVerificationTokenManager(phoneNumberIdentifiers, registrationServiceClient,
registrationRecoveryPasswordsManager, registrationRecoveryChecker),
registrationLockVerificationManager, rateLimiters, registrationFraudChecker))
.addResource(new RegistrationController(accountsManager, phoneVerificationTokenManager,
registrationLockVerificationManager, rateLimiters, registrationFraudChecker))
.build();
@BeforeEach
void setUp() {
void setUp() throws InterruptedException {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(accountsManager.update(any(UUID.class), any())).thenAnswer(invocation -> {
@ -144,6 +133,12 @@ class RegistrationControllerTest {
});
reset(registrationFraudChecker);
reset(phoneVerificationTokenManager);
when(phoneVerificationTokenManager.verify(any(), any(), any())).thenAnswer(invocation -> {
final PhoneVerificationRequest request = invocation.getArgument(2);
return request.verificationType();
});
}
@Test
@ -176,12 +171,6 @@ class RegistrationControllerTest {
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
@ -241,163 +230,36 @@ class RegistrationControllerTest {
}
}
@Test
void registrationServiceTimeout() {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
}
}
@Test
void recoveryPasswordManagerVerificationFailureOrTimeout() {
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(any()))
.thenReturn(CompletableFuture.completedFuture(UUID.randomUUID()));
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(any(), any())).thenReturn(true);
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
}
}
@ParameterizedTest
@MethodSource
void registrationServiceSessionCheck(@Nullable final RegistrationServiceSession session, final int expectedStatus,
final String message) {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(session)));
void phoneVerificationException(final Exception exception, final int expectedStatus) throws InterruptedException {
doThrow(exception)
.when(phoneVerificationTokenManager).verify(any(), any(), any());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(expectedStatus, response.getStatus(), message);
assertEquals(expectedStatus, response.getStatus());
}
}
static Stream<Arguments> registrationServiceSessionCheck() {
return Stream.of(
Arguments.of(null, 401, "session not found"),
Arguments.of(
new RegistrationServiceSession(new byte[16], "+18005551234", false, null, null, null,
SESSION_EXPIRATION_SECONDS),
400,
"session number mismatch"),
Arguments.of(
new RegistrationServiceSession(new byte[16], NUMBER, false, null, null, null, SESSION_EXPIRATION_SECONDS),
401,
"session not verified")
private static List<Arguments> phoneVerificationException() {
return List.of(
Arguments.argumentSet("Bad request", new BadRequestException(), HttpStatus.SC_BAD_REQUEST),
Arguments.argumentSet("Not authorized", new NotAuthorizedException("test"), HttpStatus.SC_UNAUTHORIZED),
Arguments.argumentSet("Forbidden", new ForbiddenException(), HttpStatus.SC_FORBIDDEN),
Arguments.argumentSet("Unexpected exception", new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE), HttpStatus.SC_SERVICE_UNAVAILABLE)
);
}
@Test
void recoveryPasswordManagerVerificationTrue() throws InterruptedException {
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(any()))
.thenReturn(CompletableFuture.completedFuture(UUID.randomUUID()));
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(any(), any())).thenReturn(true);
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
assertEquals(200, response.getStatus());
}
}
@Test
void recoveryPasswordManagerVerificationFalse() {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(false));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
assertEquals(403, response.getStatus());
}
}
@Test
void registrationRecoveryCheckerAllowsAttempt() throws InterruptedException {
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(any()))
.thenReturn(CompletableFuture.completedFuture(UUID.randomUUID()));
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(any(), any())).thenReturn(true);
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
assertEquals(200, response.getStatus());
}
}
@Test
void registrationRecoveryCheckerDisallowsAttempt() throws InterruptedException {
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(any(), any())).thenReturn(false);
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
assertEquals(403, response.getStatus());
}
}
@CartesianTest
@CartesianTest.MethodFactory("registrationLockAndDeviceTransfer")
void registrationLockAndDeviceTransfer(
final boolean deviceTransferSupported,
@Nullable final RegistrationLockError error)
throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
@Nullable final RegistrationLockError error) throws Exception {
final Account account = mock(Account.class);
when(accountsManager.getByE164(any())).thenReturn(Optional.of(account));
@ -448,16 +310,11 @@ class RegistrationControllerTest {
.format(PhoneNumberUtil.getInstance().getExampleNumber("BJ"), PhoneNumberUtil.PhoneNumberFormat.E164);
final String oldFormatBeninNumber = newFormatBeninNumber.replaceFirst("01", "");
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], newFormatBeninNumber, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
assertNotEquals(newFormatBeninNumber, oldFormatBeninNumber);
final Account account = mock(Account.class);
when(accountsManager.getByE164(oldFormatBeninNumber)).thenReturn(Optional.of(account));
when(accountsManager.getByE164(newFormatBeninNumber)).thenReturn(Optional.empty());
when(account.hasCapability(DeviceCapability.TRANSFER)).thenReturn(false);
doThrow(new WebApplicationException(RegistrationLockError.MISMATCH.getExpectedStatus()))
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any(), any(), any(), any());
@ -466,13 +323,12 @@ class RegistrationControllerTest {
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(newFormatBeninNumber, PASSWORD));
try (final Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(RegistrationLockError.MISMATCH.getExpectedStatus(), response.getStatus());
}
}
@ParameterizedTest
@CsvSource({
"false, false, false, 200",
@ -483,11 +339,6 @@ class RegistrationControllerTest {
})
void deviceTransferAvailable(final boolean existingAccount, final boolean transferSupported,
final boolean skipDeviceTransfer, final int expectedStatus) throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Optional<Account> maybeAccount;
if (existingAccount) {
@ -517,12 +368,6 @@ class RegistrationControllerTest {
// this is functionally the same as deviceTransferAvailable(existingAccount=false)
@Test
void registrationSuccess() throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
@ -546,12 +391,6 @@ class RegistrationControllerTest {
@ParameterizedTest
@MethodSource
void atomicAccountCreationConflictingChannel(final RegistrationRequest conflictingChannelRequest) {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
try (final Response response = resources.getJerseyTest()
.target("/v1/registration")
.request()
@ -637,12 +476,6 @@ class RegistrationControllerTest {
@ParameterizedTest
@MethodSource
void atomicAccountCreationPartialSignedPreKeys(final RegistrationRequest partialSignedPreKeyRequest) {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
@ -771,12 +604,6 @@ class RegistrationControllerTest {
final IdentityKey expectedPniIdentityKey,
final DeviceSpec expectedDeviceSpec) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final UUID accountIdentifier = UUID.randomUUID();
final UUID phoneNumberIdentifier = UUID.randomUUID();
final Device device = mock(Device.class);
@ -814,11 +641,6 @@ class RegistrationControllerTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void reregistrationFlag(final boolean existingAccount) throws InterruptedException {
final RegistrationServiceSession registrationSession =
new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, SESSION_EXPIRATION_SECONDS);
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(registrationSession)));
final Optional<Account> maybeAccount = Optional.ofNullable(existingAccount ? mock(Account.class) : null);
when(accountsManager.getByE164(any())).thenReturn(maybeAccount);
@ -832,6 +654,7 @@ class RegistrationControllerTest {
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(200, response.getStatus());
final AccountCreationResponse creationResponse = response.readEntity(AccountCreationResponse.class);
@ -841,12 +664,6 @@ class RegistrationControllerTest {
@Test
void registrationMissingSpqrCapability() throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));