diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java index 71fb66beb..97d32f561 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java @@ -142,7 +142,7 @@ public class AccountAuthenticator implements Authenticator { // Receipt credential expirations must be day aligned. Make sure any manually set backupVoucher is also day // aligned - final Account.BackupVoucher newPayment = new Account.BackupVoucher( + final Account.BackupVoucher newPayment = new Account.BackupVoucher( backupVoucher.receiptLevel(), backupVoucher.expiration().truncatedTo(ChronoUnit.DAYS)); final Account.BackupVoucher existingPayment = a.getBackupVoucher(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index ddd980715..75ae21e3a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -108,33 +108,26 @@ public class AccountController { public void setGcmRegistrationId(@Auth AuthenticatedDevice auth, @NotNull @Valid GcmRegistrationId registrationId) { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + final String currentGcmId = accounts.getByAccountIdentifier(auth.accountIdentifier()) + .flatMap(account -> account.getDevice(auth.deviceId())) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)) + .getGcmId(); - final Device device = account.getDevice(auth.deviceId()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - if (Objects.equals(device.getGcmId(), registrationId.gcmRegistrationId())) { + if (Objects.equals(currentGcmId, registrationId.gcmRegistrationId())) { return; } - accounts.updateDevice(account, device.getId(), d -> { - d.setApnId(null); - d.setGcmId(registrationId.gcmRegistrationId()); - d.setFetchesMessages(false); + accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), device -> { + device.setApnId(null); + device.setGcmId(registrationId.gcmRegistrationId()); + device.setFetchesMessages(false); }); } @DELETE @Path("/gcm/") public void deleteGcmRegistrationId(@Auth AuthenticatedDevice auth) { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - final Device device = account.getDevice(auth.deviceId()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - accounts.updateDevice(account, device.getId(), d -> { + accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), d -> { d.setGcmId(null); d.setFetchesMessages(false); d.setUserAgent("OWA"); @@ -148,15 +141,9 @@ public class AccountController { public void setApnRegistrationId(@Auth AuthenticatedDevice auth, @NotNull @Valid ApnRegistrationId registrationId) { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - final Device device = account.getDevice(auth.deviceId()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - // Unlike FCM tokens, we need current "last updated" timestamps for APNs tokens and so update device records // unconditionally - accounts.updateDevice(account, device.getId(), d -> { + accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), d -> { d.setApnId(registrationId.apnRegistrationId()); d.setGcmId(null); d.setFetchesMessages(false); @@ -166,20 +153,10 @@ public class AccountController { @DELETE @Path("/apn/") public void deleteApnRegistrationId(@Auth AuthenticatedDevice auth) { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - final Device device = account.getDevice(auth.deviceId()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - accounts.updateDevice(account, device.getId(), d -> { + accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), d -> { d.setApnId(null); d.setFetchesMessages(false); - if (d.getId() == 1) { - d.setUserAgent("OWI"); - } else { - d.setUserAgent("OWP"); - } + d.setUserAgent(d.isPrimary() ? "OWI" : "OWP"); }); } @@ -189,20 +166,14 @@ public class AccountController { public void setRegistrationLock(@Auth AuthenticatedDevice auth, @NotNull @Valid RegistrationLock accountLock) { final SaltedTokenHash credentials = SaltedTokenHash.generateFor(accountLock.registrationLock()); - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - accounts.update(account, + accounts.update(auth.accountIdentifier(), a -> a.setRegistrationLock(credentials.hash(), credentials.salt())); } @DELETE @Path("/registration_lock") public void removeRegistrationLock(@Auth AuthenticatedDevice auth) { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - accounts.update(account, a -> a.setRegistrationLock(null, null)); + accounts.update(auth.accountIdentifier(), a -> a.setRegistrationLock(null, null)); } @PUT @@ -240,7 +211,7 @@ public class AccountController { throw new ForbiddenException(); } - accounts.updateDevice(account, targetDeviceId, d -> d.setName(deviceName.deviceName())); + accounts.updateDevice(account.getIdentifier(IdentityType.ACI), targetDeviceId, d -> d.setName(deviceName.deviceName())); } @PUT @@ -252,10 +223,8 @@ public class AccountController { @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String signalAgent, @NotNull @Valid AccountAttributes attributes) { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - final Account updatedAccount = accounts.update(account, a -> { + final Account updatedAccount = accounts.update(auth.accountIdentifier(), a -> { a.getDevice(auth.deviceId()).ifPresent(d -> { d.setFetchesMessages(attributes.getFetchesMessages()); d.setName(attributes.getName()); @@ -306,10 +275,7 @@ public class AccountController { @ApiResponse(responseCode = "204", description = "Username successfully deleted.", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") public void deleteUsernameHash(@Auth final AuthenticatedDevice auth) { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - accounts.clearUsernameHash(account); + accounts.clearUsernameHash(auth.accountIdentifier()); } @PUT @@ -332,9 +298,6 @@ public class AccountController { @Auth final AuthenticatedDevice auth, @NotNull @Valid final ReserveUsernameHashRequest usernameRequest) throws RateLimitExceededException { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - rateLimiters.getUsernameReserveLimiter().validate(auth.accountIdentifier()); for (final byte[] hash : usernameRequest.usernameHashes()) { @@ -345,7 +308,7 @@ public class AccountController { try { final AccountsManager.UsernameReservation reservation = - accounts.reserveUsernameHash(account, usernameRequest.usernameHashes()); + accounts.reserveUsernameHash(auth.accountIdentifier(), usernameRequest.usernameHashes()); return new ReserveUsernameHashResponse(reservation.reservedUsernameHash()); } catch (final UsernameHashNotAvailableException e) { @@ -374,20 +337,17 @@ public class AccountController { @Auth final AuthenticatedDevice auth, @NotNull @Valid final ConfirmUsernameHashRequest confirmRequest) throws RateLimitExceededException { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - try { usernameHashZkProofVerifier.verifyProof(confirmRequest.zkProof(), confirmRequest.usernameHash()); } catch (final BaseUsernameException e) { throw new WebApplicationException(Response.status(422).build()); } - rateLimiters.getUsernameSetLimiter().validate(account.getUuid()); + rateLimiters.getUsernameSetLimiter().validate(auth.accountIdentifier()); try { final Account updatedAccount = accounts.confirmReservedUsernameHash( - account, + auth.accountIdentifier(), confirmRequest.usernameHash(), confirmRequest.encryptedUsername()); @@ -474,7 +434,7 @@ public class AccountController { } else { usernameLinkHandle = UUID.randomUUID(); } - updateUsernameLink(account, usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue()); + updateUsernameLink(account.getIdentifier(IdentityType.ACI), usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue()); return new UsernameLinkHandle(usernameLinkHandle); } @@ -494,10 +454,7 @@ public class AccountController { // check ratelimiter for username link operations rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.accountIdentifier()); - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - clearUsernameLink(account); + clearUsernameLink(auth.accountIdentifier()); } @GET @@ -559,24 +516,21 @@ public class AccountController { @DELETE @Path("/me") public void deleteAccount(@Auth AuthenticatedDevice auth) { - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - - accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST); + accounts.delete(auth.accountIdentifier(), AccountsManager.DeletionReason.USER_REQUEST); } - private void clearUsernameLink(final Account account) { - updateUsernameLink(account, null, null); + private void clearUsernameLink(final UUID accountIdentifier) { + updateUsernameLink(accountIdentifier, null, null); } private void updateUsernameLink( - final Account account, + final UUID accountIdentifier, @Nullable final UUID usernameLinkHandle, @Nullable final byte[] encryptedUsername) { if ((encryptedUsername == null) ^ (usernameLinkHandle == null)) { throw new IllegalStateException("Both or neither arguments must be null"); } - accounts.update(account, a -> a.setUsernameLinkDetails(usernameLinkHandle, encryptedUsername)); + accounts.update(accountIdentifier, a -> a.setUsernameLinkDetails(usernameLinkHandle, encryptedUsername)); } private void requireNotAuthenticated(final Optional authenticatedAccount) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java index d91b0fcb2..a080183c7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java @@ -179,10 +179,7 @@ public class AccountControllerV2 { @Auth AuthenticatedDevice auth, @NotNull @Valid PhoneNumberDiscoverabilityRequest phoneNumberDiscoverability) { - final Account account = accountsManager.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); - - accountsManager.update(account, a -> a.setDiscoverableByPhoneNumber( + accountsManager.update(auth.accountIdentifier(), a -> a.setDiscoverableByPhoneNumber( phoneNumberDiscoverability.discoverableByPhoneNumber())); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 9ed468073..cb893d718 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -159,10 +159,7 @@ public class DeviceController { throw new ForbiddenException(); } - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); - - accounts.removeDevice(account, deviceId); + accounts.removeDevice(auth.accountIdentifier(), deviceId); } /** @@ -282,20 +279,21 @@ public class DeviceController { } try { - final Pair accountAndDevice = accounts.addDevice(account, new DeviceSpec(deviceAttributes.name(), - authorizationHeader.getPassword(), - signalAgent, - capabilities, - deviceAttributes.registrationId(), - deviceAttributes.phoneNumberIdentityRegistrationId(), - deviceAttributes.fetchesMessages(), - deviceActivationRequest.apnToken(), - deviceActivationRequest.gcmToken(), - deviceActivationRequest.aciSignedPreKey(), - deviceActivationRequest.pniSignedPreKey(), - deviceActivationRequest.aciPqLastResortPreKey(), - deviceActivationRequest.pniPqLastResortPreKey()), - linkDeviceRequest.verificationCode()); + final Pair accountAndDevice = accounts.addDevice(account.getIdentifier(IdentityType.ACI), + new DeviceSpec(deviceAttributes.name(), + authorizationHeader.getPassword(), + signalAgent, + capabilities, + deviceAttributes.registrationId(), + deviceAttributes.phoneNumberIdentityRegistrationId(), + deviceAttributes.fetchesMessages(), + deviceActivationRequest.apnToken(), + deviceActivationRequest.gcmToken(), + deviceActivationRequest.aciSignedPreKey(), + deviceActivationRequest.pniSignedPreKey(), + deviceActivationRequest.aciPqLastResortPreKey(), + deviceActivationRequest.pniPqLastResortPreKey()), + linkDeviceRequest.verificationCode()); return new LinkDeviceResponse( accountAndDevice.first().getIdentifier(IdentityType.ACI), @@ -384,15 +382,8 @@ public class DeviceController { @PUT @Produces(MediaType.APPLICATION_JSON) @Path("/capabilities") - public void setCapabilities(@Auth final AuthenticatedDevice auth, - - @NotNull - final Map capabilities) { - - final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); - - accounts.updateDevice(account, auth.deviceId(), + public void setCapabilities(@Auth final AuthenticatedDevice auth, @NotNull final Map capabilities) { + accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), d -> d.setCapabilities(DeviceCapabilityAdapter.mapToSet(capabilities))); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java index 3144309b8..014aa493d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java @@ -15,16 +15,12 @@ import jakarta.ws.rs.Consumes; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.Produces; -import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response.Status; import java.time.Clock; import java.time.Instant; import java.util.Objects; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; -import java.util.function.Function; import javax.annotation.Nonnull; import org.glassfish.jersey.server.ManagedAsync; import org.signal.libsignal.zkgroup.InvalidInputException; @@ -35,7 +31,6 @@ import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.entities.RedeemReceiptRequest; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountBadge; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager; @@ -121,9 +116,6 @@ public class DonationController { .build(); } - final Account account = accountsManager.getByAccountIdentifier(auth.accountIdentifier()) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - final boolean receiptMatched = redeemedReceiptsManager.put( receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.accountIdentifier()).join(); if (!receiptMatched) { @@ -133,7 +125,7 @@ public class DonationController { .build(); } - accountsManager.update(account, a -> { + accountsManager.update(auth.accountIdentifier(), a -> { a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible())); if (request.isPrimary()) { a.makeBadgePrimaryIfExists(clock, badgeId); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index 54764a190..3adf540b6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -215,7 +215,7 @@ public class ProfileController { currentAvatar.ifPresent(profilesManager::deleteAvatar); } - accountsManager.update(account, a -> { + accountsManager.update(account.getIdentifier(IdentityType.ACI), a -> { final List updatedBadges = request.badges() .map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges())) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcService.java index b0392ac51..7c3775aea 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcService.java @@ -97,8 +97,8 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { @Override public DeleteAccountResponse deleteAccount(final DeleteAccountRequest request) { - accountsManager.delete(getAuthenticatedAccount(AuthenticationUtil.requireAuthenticatedPrimaryDevice()), - AccountsManager.DeletionReason.USER_REQUEST); + accountsManager.delete(AuthenticationUtil.requireAuthenticatedPrimaryDevice().accountIdentifier(), + AccountsManager.DeletionReason.USER_REQUEST); return DeleteAccountResponse.getDefaultInstance(); } @@ -110,7 +110,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { final SaltedTokenHash credentials = SaltedTokenHash.generateFor(HexFormat.of().withLowerCase().formatHex(request.getRegistrationLock().toByteArray())); - accountsManager.update(getAuthenticatedAccount(AuthenticationUtil.requireAuthenticatedPrimaryDevice()), + accountsManager.update(AuthenticationUtil.requireAuthenticatedPrimaryDevice().accountIdentifier(), account -> account.setRegistrationLock(credentials.hash(), credentials.salt())); return SetRegistrationLockResponse.getDefaultInstance(); @@ -118,7 +118,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { @Override public ClearRegistrationLockResponse clearRegistrationLock(final ClearRegistrationLockRequest request) { - accountsManager.update(getAuthenticatedAccount(AuthenticationUtil.requireAuthenticatedPrimaryDevice()), + accountsManager.update(AuthenticationUtil.requireAuthenticatedPrimaryDevice().accountIdentifier(), account -> account.setRegistrationLock(null, null)); return ClearRegistrationLockResponse.getDefaultInstance(); @@ -142,11 +142,9 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { rateLimiters.getUsernameReserveLimiter().validate(authenticatedDevice.accountIdentifier()); - final Account account = getAuthenticatedAccount(); - try { final AccountsManager.UsernameReservation usernameReservation = - accountsManager.reserveUsernameHash(account, usernameHashes); + accountsManager.reserveUsernameHash(authenticatedDevice.accountIdentifier(), usernameHashes); return ReserveUsernameHashResponse.newBuilder() .setUsernameHash(ByteString.copyFrom(usernameReservation.reservedUsernameHash())) @@ -172,7 +170,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { rateLimiters.getUsernameSetLimiter().validate(authenticatedDevice.accountIdentifier()); try { - final Account updatedAccount = accountsManager.confirmReservedUsernameHash(getAuthenticatedAccount(), + final Account updatedAccount = accountsManager.confirmReservedUsernameHash(authenticatedDevice.accountIdentifier(), request.getUsernameHash().toByteArray(), request.getUsernameCiphertext().toByteArray()); @@ -196,7 +194,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { @Override public DeleteUsernameHashResponse deleteUsernameHash(final DeleteUsernameHashRequest request) { - accountsManager.clearUsernameHash(getAuthenticatedAccount()); + accountsManager.clearUsernameHash(AuthenticationUtil.requireAuthenticatedDevice().accountIdentifier()); return DeleteUsernameHashResponse.getDefaultInstance(); } @@ -220,7 +218,8 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { ? account.getUsernameLinkHandle() : UUID.randomUUID(); - accountsManager.update(account, a -> a.setUsernameLinkDetails(linkHandle, request.getUsernameCiphertext().toByteArray())); + accountsManager.update(account.getIdentifier(IdentityType.ACI), + a -> a.setUsernameLinkDetails(linkHandle, request.getUsernameCiphertext().toByteArray())); return responseBuilder.setUsernameLinkHandle(UUIDUtil.toByteString(linkHandle)).build(); } @@ -232,7 +231,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { rateLimiters.getUsernameLinkOperationLimiter().validate(authenticatedDevice.accountIdentifier()); - accountsManager.update(getAuthenticatedAccount(), a -> a.setUsernameLinkDetails(null, null)); + accountsManager.update(authenticatedDevice.accountIdentifier(), a -> a.setUsernameLinkDetails(null, null)); return DeleteUsernameLinkResponse.getDefaultInstance(); } @@ -245,7 +244,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH, request.getUnidentifiedAccessKey().size())); } - accountsManager.update(getAuthenticatedAccount(), account -> { + accountsManager.update(AuthenticationUtil.requireAuthenticatedDevice().accountIdentifier(), account -> { account.setUnrestrictedUnidentifiedAccess(request.getAllowUnrestrictedUnidentifiedAccess()); account.setUnidentifiedAccessKey(request.getAllowUnrestrictedUnidentifiedAccess() ? null : request.getUnidentifiedAccessKey().toByteArray()); }); @@ -255,7 +254,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { @Override public SetDiscoverableByPhoneNumberResponse setDiscoverableByPhoneNumber(final SetDiscoverableByPhoneNumberRequest request) { - accountsManager.update(getAuthenticatedAccount(), + accountsManager.update(AuthenticationUtil.requireAuthenticatedDevice().accountIdentifier(), account -> account.setDiscoverableByPhoneNumber(request.getDiscoverableByPhoneNumber())); return SetDiscoverableByPhoneNumberResponse.getDefaultInstance(); @@ -283,7 +282,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { rateLimiters.getSetZkCredentialKeyLimiter().validate(authenticatedDevice.accountIdentifier()); - accountsManager.update(authenticatedAccount, account -> account.setZkCredentialKey(zkCredentialKey)); + accountsManager.update(authenticatedDevice.accountIdentifier(), account -> account.setZkCredentialKey(zkCredentialKey)); return SetZkCredentialKeyResponse.getDefaultInstance(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcService.java index fefbf1c6f..a8decc712 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcService.java @@ -80,9 +80,7 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase { throw GrpcExceptions.badAuthentication("linked devices cannot remove devices other than themselves"); } - final byte deviceId = DeviceIdUtil.validate(request.getId()); - - accountsManager.removeDevice(getAuthenticatedAccount(), deviceId); + accountsManager.removeDevice(authenticatedDevice.accountIdentifier(), DeviceIdUtil.validate(request.getId())); return RemoveDeviceResponse.getDefaultInstance(); } @@ -106,7 +104,7 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase { return SetDeviceNameResponse.newBuilder().setTargetDeviceNotFound(NotFound.getDefaultInstance()).build(); } - accountsManager.updateDevice(account, deviceId, device -> device.setName(request.getName().toByteArray())); + accountsManager.updateDevice(account.getIdentifier(IdentityType.ACI), deviceId, device -> device.setName(request.getName().toByteArray())); return SetDeviceNameResponse.newBuilder().setSuccess(Empty.getDefaultInstance()).build(); } @@ -141,7 +139,7 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase { .orElseThrow(() -> GrpcExceptions.invalidCredentials("invalid credentials")); if (!Objects.equals(device.getApnId(), apnsToken) || !Objects.equals(device.getGcmId(), fcmToken)) { - accountsManager.updateDevice(account, authenticatedDevice.deviceId(), d -> { + accountsManager.updateDevice(account.getIdentifier(IdentityType.ACI), authenticatedDevice.deviceId(), d -> { d.setApnId(apnsToken); d.setGcmId(fcmToken); d.setFetchesMessages(false); @@ -154,9 +152,8 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase { @Override public ClearPushTokenResponse clearPushToken(final ClearPushTokenRequest request) { final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); - final Account account = getAuthenticatedAccount(); - accountsManager.updateDevice(account, authenticatedDevice.deviceId(), device -> { + accountsManager.updateDevice(authenticatedDevice.accountIdentifier(), authenticatedDevice.deviceId(), device -> { if (StringUtils.isNotBlank(device.getApnId())) { device.setUserAgent(device.isPrimary() ? "OWI" : "OWP"); } else if (StringUtils.isNotBlank(device.getGcmId())) { @@ -179,7 +176,7 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase { .map(DeviceCapabilityUtil::fromGrpcDeviceCapability) .collect(Collectors.toSet()); - accountsManager.updateDevice(getAuthenticatedAccount(), authenticatedDevice.deviceId(), + accountsManager.updateDevice(authenticatedDevice.accountIdentifier(), authenticatedDevice.deviceId(), device -> device.setCapabilities(capabilities)); return SetCapabilitiesResponse.getDefaultInstance(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java index 6ebc9f12d..fe11b81a2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java @@ -124,7 +124,7 @@ public class ProfileGrpcService extends SimpleProfileGrpc.ProfileImplBase { } }; - profilesManager.set(account.getUuid(), + profilesManager.set(account.getIdentifier(IdentityType.ACI), new VersionedProfile( request.getVersion(), request.getName().toByteArray(), @@ -135,7 +135,7 @@ public class ProfileGrpcService extends SimpleProfileGrpc.ProfileImplBase { request.getPhoneNumberSharing().toByteArray(), request.getCommitment().toByteArray())); - accountsManager.update(account, a -> { + accountsManager.update(account.getIdentifier(IdentityType.ACI), a -> { final List updatedBadges = Optional.of(request.getBadgeIdsList()) .map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java index 753c311ac..1050d2d0e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java @@ -17,6 +17,7 @@ import java.util.function.BiConsumer; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -195,7 +196,7 @@ public class PushNotificationManager { // updating an uninstalled feedback timestamp though. accountsManager.getByAccountIdentifier(account.getUuid()).ifPresent(rereadAccount -> rereadAccount.getDevice(device.getId()).ifPresent(rereadDevice -> - accountsManager.updateDevice(rereadAccount, device.getId(), d -> { + accountsManager.updateDevice(rereadAccount.getIdentifier(IdentityType.ACI), device.getId(), d -> { // Don't clear the token if it's already changed if (originalToken.equals(getPushToken(d, tokenType))) { switch (tokenType) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 5c162f868..3a766951a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -433,11 +433,15 @@ public class AccountsManager extends RedisPubSubAdapter implemen return account; } - public Pair addDevice(final Account account, final DeviceSpec deviceSpec, final String linkDeviceToken) + public Pair addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final String linkDeviceToken) throws LinkDeviceTokenAlreadyUsedException { - return accountLockManager.withLock(Set.of(account.getPhoneNumberIdentifier()), - () -> addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, linkDeviceToken, MAX_UPDATE_ATTEMPTS), + final UUID phoneNumberIdentifier = accounts.getByAccountIdentifier(accountIdentifier) + .map(account -> account.getIdentifier(IdentityType.PNI)) + .orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier)); + + return accountLockManager.withLock(Set.of(phoneNumberIdentifier), + () -> addDevice(accountIdentifier, deviceSpec, linkDeviceToken, MAX_UPDATE_ATTEMPTS), accountLockExecutor); } @@ -622,13 +626,18 @@ public class AccountsManager extends RedisPubSubAdapter implemen * * @return the updated Account */ - public Account removeDevice(final Account account, final byte deviceId) { + public Account removeDevice(final UUID accountIdentifier, final byte deviceId) { if (deviceId == Device.PRIMARY_ID) { throw new IllegalArgumentException("Cannot remove primary device"); } - return accountLockManager.withLock(Set.of(account.getPhoneNumberIdentifier()), - () -> removeDevice(account.getIdentifier(IdentityType.ACI), deviceId, MAX_UPDATE_ATTEMPTS), + // Always fetch a fresh, non-cached copy of the account before making modifications + final UUID phoneNumberIdentifier = accounts.getByAccountIdentifier(accountIdentifier) + .map(account -> account.getIdentifier(IdentityType.PNI)) + .orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier)); + + return accountLockManager.withLock(Set.of(phoneNumberIdentifier), + () -> removeDevice(accountIdentifier, deviceId, MAX_UPDATE_ATTEMPTS), accountLockExecutor); } @@ -672,13 +681,17 @@ public class AccountsManager extends RedisPubSubAdapter implemen } } - public Account changeNumber(final Account account, + public Account changeNumber(final UUID accountIdentifier, final String targetNumber, final IdentityKey pniIdentityKey, final Map pniSignedPreKeys, final Map pniPqLastResortPreKeys, final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { + // Always fetch a fresh, non-cached copy of the account before making modifications + final Account account = accounts.getByAccountIdentifier(accountIdentifier) + .orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier)); + final UUID targetPhoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber).join(); try { @@ -743,7 +756,6 @@ public class AccountsManager extends RedisPubSubAdapter implemen buildPniKeyWriteItems(targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); return updateWithRetries( - account, a -> { setPniKeys(a, pniIdentityKey, pniRegistrationIds); return true; @@ -815,18 +827,23 @@ public class AccountsManager extends RedisPubSubAdapter implemen /// Reserve a username hash so that no other accounts may take it. /// - /// The reserved hash can later be set with [#confirmReservedUsernameHash(Account, byte\[\], byte\[\])]. The + /// The reserved hash can later be set with [#confirmReservedUsernameHash(UUID, byte\[\], byte\[\])]. The /// reservation will eventually expire, after which point confirmReservedUsernameHash may fail if another account has /// taken the username hash. /// - /// @param account the account to update + /// @param accountIdentifier the unique identifier of the account to update /// @param requestedUsernameHashes the list of username hashes to attempt to reserve /// /// @return the reserved username hash /// /// @throws UsernameHashNotAvailableException if none of the given username hashes are available - public UsernameReservation reserveUsernameHash(final Account account, final List requestedUsernameHashes) + public UsernameReservation reserveUsernameHash(final UUID accountIdentifier, final List requestedUsernameHashes) throws UsernameHashNotAvailableException { + + // Always fetch a fresh, non-cached copy of the account before making modifications + final Account account = accounts.getByAccountIdentifier(accountIdentifier) + .orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier)); + if (account.getUsernameHash().filter( oldHash -> requestedUsernameHashes.stream().anyMatch(hash -> Arrays.equals(oldHash, hash))) .isPresent()) { @@ -844,7 +861,6 @@ public class AccountsManager extends RedisPubSubAdapter implemen redisDelete(account); final Account updatedAccount = updateWithRetries( - account, _ -> true, a -> reservedUsernameHash.set( checkAndReserveNextUsernameHash(a, new ArrayDeque<>(requestedUsernameHashes))), @@ -873,9 +889,9 @@ public class AccountsManager extends RedisPubSubAdapter implemen } } - /// Set a username hash previously reserved with {@link #reserveUsernameHash(Account, List)} + /// Set a username hash previously reserved with {@link #reserveUsernameHash(UUID, List)} /// - /// @param account the account to update + /// @param accountIdentifier identifier of the account to update /// @param reservedUsernameHash the previously reserved username hash /// @param encryptedUsername the encrypted form of the previously reserved username for the username link /// @@ -884,9 +900,13 @@ public class AccountsManager extends RedisPubSubAdapter implemen /// @throws UsernameHashNotAvailableException if the reserved username hash has been taken (because the reservation /// expired) /// @throws UsernameReservationNotFoundException if `reservedUsernameHash` was not reserved for the account - public Account confirmReservedUsernameHash(final Account account, final byte[] reservedUsernameHash, @Nullable final byte[] encryptedUsername) + public Account confirmReservedUsernameHash(final UUID accountIdentifier, final byte[] reservedUsernameHash, @Nullable final byte[] encryptedUsername) throws UsernameReservationNotFoundException, UsernameHashNotAvailableException { + // Always fetch a fresh, non-cached copy of the account before making modifications + final Account account = accounts.getByAccountIdentifier(accountIdentifier) + .orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier)); + if (account.getUsernameHash().map(currentUsernameHash -> Arrays.equals(currentUsernameHash, reservedUsernameHash)).orElse(false)) { // the client likely already succeeded and is retrying return account; @@ -900,7 +920,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen redisDelete(account); - final Account updatedAccount = updateWithRetries(account, + final Account updatedAccount = updateWithRetries( _ -> true, a -> accounts.confirmUsernameHash(a, reservedUsernameHash, encryptedUsername), () -> accounts.getByAccountIdentifier(account.getUuid()).orElseThrow(), @@ -911,13 +931,10 @@ public class AccountsManager extends RedisPubSubAdapter implemen return updatedAccount; } - public Account clearUsernameHash(final Account account) { - redisDelete(account); - - final Account updatedAccount = updateWithRetries(account, - _ -> true, + public Account clearUsernameHash(final UUID accountIdentifier) { + final Account updatedAccount = updateWithRetries(_ -> true, accounts::clearUsernameHash, - () -> accounts.getByAccountIdentifier(account.getUuid()).orElseThrow(), + () -> accounts.getByAccountIdentifier(accountIdentifier).orElseThrow(), AccountChangeValidator.USERNAME_CHANGE_VALIDATOR); redisDelete(updatedAccount); @@ -925,8 +942,15 @@ public class AccountsManager extends RedisPubSubAdapter implemen return updatedAccount; } - public Account update(Account account, Consumer updater) { - return update(account, a -> { + public Account update(final Account account, final Consumer updater) { + final Account updatedAccount = update(account.getIdentifier(IdentityType.ACI), updater); + account.markStale(); + + return updatedAccount; + } + + public Account update(final UUID accountIdentifier, final Consumer updater) { + return update(accountIdentifier, a -> { updater.accept(a); // assume that all updaters passed to the public method actually modify the account return true; @@ -934,11 +958,11 @@ public class AccountsManager extends RedisPubSubAdapter implemen } /** - * Specialized version of {@link #updateDevice(Account, byte, Consumer)} that minimizes potentially contentious and + * Specialized version of {@link #updateDevice(UUID, byte, Consumer)} that minimizes potentially contentious and * redundant updates of {@code device.lastSeen} */ - public Account updateDeviceLastSeen(Account account, Device device, final long lastSeen) { - return update(account, a -> { + public Account updateDeviceLastSeen(UUID accountIdentifier, Device device, final long lastSeen) { + return update(accountIdentifier, a -> { final Optional maybeDevice = a.getDevice(device.getId()); @@ -956,21 +980,16 @@ public class AccountsManager extends RedisPubSubAdapter implemen } /** - * @param account account to update + * @param accountIdentifier identifier of account to update * @param updater must return {@code true} if the account was actually updated */ - private Account update(Account account, Function updater) { + private Account update(UUID accountIdentifier, Function updater) { return updateTimer.record(() -> { - redisDelete(account); - - final UUID uuid = account.getUuid(); - - final Account updatedAccount = updateWithRetries(account, - updater, + final Account updatedAccount = updateWithRetries(updater, accounts::update, - () -> accounts.getByAccountIdentifier(uuid).orElseThrow(), + () -> accounts.getByAccountIdentifier(accountIdentifier).orElseThrow(), AccountChangeValidator.GENERAL_CHANGE_VALIDATOR); redisSet(updatedAccount); @@ -979,49 +998,38 @@ public class AccountsManager extends RedisPubSubAdapter implemen }); } - private Account updateWithRetries(Account account, - final Function updater, + private Account updateWithRetries(final Function updater, final ThrowingConsumer persister, final Supplier retriever, final AccountChangeValidator changeValidator) throws E { - Account originalAccount = AccountUtil.cloneAccountAsNotStale(account); - - if (!updater.apply(account)) { - return account; - } - final int maxTries = 10; int tries = 0; while (tries < maxTries) { - try { - persister.accept(account); - - final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account); - account.markStale(); - - changeValidator.validateChange(originalAccount, updatedAccount); - - return updatedAccount; - } catch (final ContestedOptimisticLockException e) { - tries++; - - account = retriever.get(); - originalAccount = AccountUtil.cloneAccountAsNotStale(account); + final Account account = retriever.get(); + final Account originalAccount = AccountUtil.cloneAccountAsNotStale(account); if (!updater.apply(account)) { return account; } + + persister.accept(account); + + changeValidator.validateChange(originalAccount, account); + + return account; + } catch (final ContestedOptimisticLockException e) { + tries++; } } throw new OptimisticLockRetryLimitExceededException(); } - public Account updateDevice(Account account, byte deviceId, Consumer deviceUpdater) { - return update(account, a -> { + public Account updateDevice(final UUID accountIdentifier, byte deviceId, Consumer deviceUpdater) { + return update(accountIdentifier, a -> { a.getDevice(deviceId).ifPresent(deviceUpdater); // assume that all updaters passed to the public method actually modify the device return true; @@ -1106,9 +1114,19 @@ public class AccountsManager extends RedisPubSubAdapter implemen return accounts.getAll(segments, scheduler); } - public void delete(final Account account, final DeletionReason deletionReason) { + public void delete(final UUID accountIdentifier, final DeletionReason deletionReason) { final Timer.Sample sample = Timer.start(); + final Optional maybeAccount = accounts.getByAccountIdentifier(accountIdentifier); + + if (maybeAccount.isEmpty()) { + // In most cases, failing to find an account would be an error, but in this case, it's probably a sign that we've + // already succeeded in deleting the account and this is a spurious retry. + return; + } + + final Account account = maybeAccount.get(); + try { accountLockManager.withLock(Set.of(account.getPhoneNumberIdentifier()), () -> { delete(account); @@ -1328,21 +1346,6 @@ public class AccountsManager extends RedisPubSubAdapter implemen getAccountEntityKey(account.getUuid()))))); } - private CompletableFuture redisDeleteAsync(final Account account) { - final Timer.Sample sample = Timer.start(); - - final String[] keysToDelete = new String[]{ - getAccountMapKey(account.getPhoneNumberIdentifier().toString()), - getAccountEntityKey(account.getUuid()) - }; - - return ResilienceUtil.getGeneralRedisRetry(RETRY_NAME).executeCompletionStage(retryExecutor, - () -> cacheCluster.withCluster(connection -> connection.async().del(keysToDelete)) - .thenRun(Util.NOOP)) - .toCompletableFuture() - .whenComplete((_, _) -> sample.stop(redisDeleteTimer)); - } - public CompletableFuture> waitForNewLinkedDevice( final UUID accountIdentifier, final Device linkingDevice, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 4cf04b624..f0fb4d7b2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -80,7 +80,7 @@ public class ChangeNumberManager { } final Account updatedAccount = accountsManager.changeNumber( - account, number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds); + account.getIdentifier(IdentityType.ACI), number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds); try { // Now that we've actually updated the account, populate the "updated PNI" field on all envelopes diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 10bc1e1f4..3eb38a615 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -311,7 +311,7 @@ public class MessagePersister implements Managed { } else { logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", accountUuid, deviceId); - return Mono.fromRunnable(() -> accountsManager.removeDevice(account, deviceId)) + return Mono.fromRunnable(() -> accountsManager.removeDevice(accountUuid, deviceId)) .subscribeOn(persistQueueScheduler) .then(Mono.empty()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java index d464f7e01..2efa2521f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java @@ -13,6 +13,7 @@ import net.sourceforge.argparse4j.inf.Subparser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager.DeletionReason; @@ -51,7 +52,7 @@ public class DeleteUserCommand extends AbstractCommandWithDependencies { Optional account = accountsManager.getByE164(user); if (account.isPresent()) { - accountsManager.delete(account.get(), DeletionReason.ADMIN_DELETED); + accountsManager.delete(account.get().getIdentifier(IdentityType.ACI), DeletionReason.ADMIN_DELETED); logger.warn("Removed " + account.get().getNumber()); } else { logger.warn("Account not found"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredAccountsCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredAccountsCommand.java index 2e59594cb..0d5ed0280 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredAccountsCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredAccountsCommand.java @@ -16,6 +16,7 @@ import java.time.Instant; import net.sourceforge.argparse4j.inf.Subparser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import reactor.core.publisher.Mono; @@ -67,7 +68,7 @@ public class RemoveExpiredAccountsCommand extends AbstractSinglePassCrawlAccount .flatMap(expiredAccount -> { final Mono deleteAccountMono = isDryRun ? Mono.empty() - : Mono.fromRunnable(() -> getCommandDependencies().accountsManager().delete(expiredAccount, AccountsManager.DeletionReason.EXPIRED)) + : Mono.fromRunnable(() -> getCommandDependencies().accountsManager().delete(expiredAccount.getIdentifier(IdentityType.ACI), AccountsManager.DeletionReason.EXPIRED)) .subscribeOn(Schedulers.boundedElastic()) .then(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java index b9e253f0b..c6481bf4a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java @@ -21,6 +21,7 @@ import java.util.stream.Collectors; import net.sourceforge.argparse4j.inf.Subparser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import reactor.core.publisher.Flux; @@ -130,7 +131,7 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc return Flux.fromIterable(expiredDevices) .flatMap(deviceId -> - Mono.fromRunnable(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId)) + Mono.fromRunnable(() -> getCommandDependencies().accountsManager().removeDevice(account.getIdentifier(IdentityType.ACI), deviceId)) .retryWhen(Retry.backoff(maxRetries, Duration.ofSeconds(1)) .doAfterRetry(ignored -> retryCounter.increment()) .onRetryExhaustedThrow((spec, rs) -> rs.failure())) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredUsernameHoldsCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredUsernameHoldsCommand.java index 427727471..fb90e9f93 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredUsernameHoldsCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredUsernameHoldsCommand.java @@ -20,6 +20,7 @@ import java.util.concurrent.atomic.AtomicLong; import net.sourceforge.argparse4j.inf.Subparser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import reactor.core.publisher.Flux; @@ -85,7 +86,7 @@ public class RemoveExpiredUsernameHoldsCommand extends AbstractSinglePassCrawlAc final int holdsToRemove = removeExpired(holds); final Mono purgeMono = isDryRun || holdsToRemove == 0 ? Mono.empty() - : Mono.fromRunnable(() -> accountManager.update(account, a -> a.setUsernameHolds(holds))) + : Mono.fromRunnable(() -> accountManager.update(account.getIdentifier(IdentityType.ACI), a -> a.setUsernameHolds(holds))) .subscribeOn(Schedulers.boundedElastic()) .then(); Metrics.counter(INSPECTED_ACCOUNTS_COUNTER_NAME, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/SetUserDiscoverabilityCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/SetUserDiscoverabilityCommand.java index 39e4a69a8..af93a46db 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/SetUserDiscoverabilityCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/SetUserDiscoverabilityCommand.java @@ -12,6 +12,7 @@ import java.util.UUID; import net.sourceforge.argparse4j.inf.Namespace; import net.sourceforge.argparse4j.inf.Subparser; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -61,7 +62,7 @@ public class SetUserDiscoverabilityCommand extends AbstractCommandWithDependenci maybeAccount.ifPresentOrElse(account -> { final boolean initiallyDiscoverable = account.isDiscoverableByPhoneNumber(); - accountsManager.update(account, a -> a.setDiscoverableByPhoneNumber(namespace.getBoolean("discoverable"))); + accountsManager.update(account.getIdentifier(IdentityType.ACI), a -> a.setDiscoverableByPhoneNumber(namespace.getBoolean("discoverable"))); System.out.format("Set discoverability flag for %s to %s (was previously %s)\n", namespace.getString("user"), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java index 4adf6a375..1382d3f29 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java @@ -13,6 +13,7 @@ import net.sourceforge.argparse4j.impl.Arguments; import net.sourceforge.argparse4j.inf.Namespace; import net.sourceforge.argparse4j.inf.Subparser; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; @@ -51,9 +52,6 @@ public class UnlinkDeviceCommand extends AbstractCommandWithDependencies { final UUID aci = UUID.fromString(namespace.getString("uuid").trim()); final List deviceIds = namespace.getList("deviceIds"); - Account account = deps.accountsManager().getByAccountIdentifier(aci) - .orElseThrow(() -> new IllegalArgumentException("account id " + aci + " does not exist")); - if (deviceIds.contains(Device.PRIMARY_ID)) { throw new IllegalArgumentException("cannot delete primary device"); } @@ -61,7 +59,7 @@ public class UnlinkDeviceCommand extends AbstractCommandWithDependencies { for (byte deviceId : deviceIds) { /** see {@link org.whispersystems.textsecuregcm.controllers.DeviceController#removeDevice} */ System.out.format("Removing device %s::%d\n", aci, deviceId); - deps.accountsManager().removeDevice(account, deviceId); + deps.accountsManager().removeDevice(aci, deviceId); } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDevicesWithIdlePrimaryCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDevicesWithIdlePrimaryCommand.java index a83ebfb00..81f665a54 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDevicesWithIdlePrimaryCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDevicesWithIdlePrimaryCommand.java @@ -93,7 +93,7 @@ public class UnlinkDevicesWithIdlePrimaryCommand extends AbstractSinglePassCrawl .filter(account -> isPrimaryDeviceIdle(account, currentTime, idleDurationThreshold)) .flatMap(accountWithIdlePrimaryDevice -> Flux.fromIterable(accountWithIdlePrimaryDevice.getDevices()) .filter(device -> !device.isPrimary()) - .map(linkedDevice -> Tuples.of(accountWithIdlePrimaryDevice, linkedDevice.getId()))) + .map(linkedDevice -> Tuples.of(accountWithIdlePrimaryDevice.getIdentifier(IdentityType.ACI), linkedDevice.getId()))) .flatMap(accountAndLinkedDeviceId -> { final Mono unlinkDeviceMono = isDryRun ? Mono.empty() @@ -104,8 +104,7 @@ public class UnlinkDevicesWithIdlePrimaryCommand extends AbstractSinglePassCrawl .doOnSuccess(ignored -> unlinkDeviceCounter.increment()) .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).maxBackoff(Duration.ofSeconds(4))) .onErrorResume(throwable -> { - logger.warn("Failed to unlink device to delete account {}:{}", accountAndLinkedDeviceId.getT1().getIdentifier( - IdentityType.ACI), accountAndLinkedDeviceId.getT2(), throwable); + logger.warn("Failed to unlink device to delete account {}:{}", accountAndLinkedDeviceId.getT1(), accountAndLinkedDeviceId.getT2(), throwable); return Mono.empty(); }); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticatorTest.java index 7b72a87e6..fdae44ca6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticatorTest.java @@ -64,6 +64,9 @@ class AccountAuthenticatorTest { oldAccount = AccountsHelper.generateTestAccount("+14088675311", UUID.fromString("adfce52b-9299-4c25-9c51-412fb420c6a6"), UUID.randomUUID(), List.of(generateTestDevice(oldTime)), null); AccountsHelper.setupMockUpdate(accountsManager); + AccountsHelper.setupMockGet(accountsManager, acct1); + AccountsHelper.setupMockGet(accountsManager, acct2); + AccountsHelper.setupMockGet(accountsManager, oldAccount); } private static Device generateTestDevice(final long lastSeen) { @@ -81,17 +84,14 @@ class AccountAuthenticatorTest { final Device device1 = acct1.getDevices().stream().findFirst().orElseThrow(); final Device device2 = acct2.getDevices().stream().findFirst().orElseThrow(); - final Account updatedAcct1 = accountAuthenticator.updateLastSeen(acct1, device1); - final Account updatedAcct2 = accountAuthenticator.updateLastSeen(acct2, device2); + accountAuthenticator.updateLastSeen(acct1, device1); + accountAuthenticator.updateLastSeen(acct2, device2); - verify(accountsManager, never()).updateDeviceLastSeen(eq(acct1), any(), anyLong()); - verify(accountsManager).updateDeviceLastSeen(eq(acct2), eq(device2), anyLong()); + verify(accountsManager, never()).updateDeviceLastSeen(eq(acct1.getIdentifier(IdentityType.ACI)), any(), anyLong()); + verify(accountsManager).updateDeviceLastSeen(eq(acct2.getIdentifier(IdentityType.ACI)), eq(device2), anyLong()); assertThat(device1.getLastSeen()).isEqualTo(yesterday); assertThat(device2.getLastSeen()).isEqualTo(today); - - assertThat(acct1).isSameAs(updatedAcct1); - assertThat(acct2).isNotSameAs(updatedAcct2); } @Test @@ -101,17 +101,14 @@ class AccountAuthenticatorTest { final Device device1 = acct1.getDevices().stream().findFirst().orElseThrow(); final Device device2 = acct2.getDevices().stream().findFirst().orElseThrow(); - final Account updatedAcct1 = accountAuthenticator.updateLastSeen(acct1, device1); - final Account updatedAcct2 = accountAuthenticator.updateLastSeen(acct2, device2); + accountAuthenticator.updateLastSeen(acct1, device1); + accountAuthenticator.updateLastSeen(acct2, device2); - verify(accountsManager, never()).updateDeviceLastSeen(eq(acct1), any(), anyLong()); - verify(accountsManager, never()).updateDeviceLastSeen(eq(acct2), any(), anyLong()); + verify(accountsManager, never()).updateDeviceLastSeen(eq(acct1.getIdentifier(IdentityType.ACI)), any(), anyLong()); + verify(accountsManager, never()).updateDeviceLastSeen(eq(acct2.getIdentifier(IdentityType.ACI)), any(), anyLong()); assertThat(device1.getLastSeen()).isEqualTo(yesterday); assertThat(device2.getLastSeen()).isEqualTo(yesterday); - - assertThat(acct1).isSameAs(updatedAcct1); - assertThat(acct2).isSameAs(updatedAcct2); } @Test @@ -121,17 +118,14 @@ class AccountAuthenticatorTest { final Device device1 = acct1.getDevices().stream().findFirst().orElseThrow(); final Device device2 = acct2.getDevices().stream().findFirst().orElseThrow(); - final Account updatedAcct1 = accountAuthenticator.updateLastSeen(acct1, device1); - final Account updatedAcct2 = accountAuthenticator.updateLastSeen(acct2, device2); + accountAuthenticator.updateLastSeen(acct1, device1); + accountAuthenticator.updateLastSeen(acct2, device2); - verify(accountsManager).updateDeviceLastSeen(eq(acct1), eq(device1), anyLong()); - verify(accountsManager).updateDeviceLastSeen(eq(acct2), eq(device2), anyLong()); + verify(accountsManager).updateDeviceLastSeen(eq(acct1.getIdentifier(IdentityType.ACI)), eq(device1), anyLong()); + verify(accountsManager).updateDeviceLastSeen(eq(acct2.getIdentifier(IdentityType.ACI)), eq(device2), anyLong()); assertThat(device1.getLastSeen()).isEqualTo(today); assertThat(device2.getLastSeen()).isEqualTo(today); - - assertThat(updatedAcct1).isNotSameAs(acct1); - assertThat(updatedAcct2).isNotSameAs(acct2); } @Test @@ -142,7 +136,7 @@ class AccountAuthenticatorTest { accountAuthenticator.updateLastSeen(oldAccount, device); - verify(accountsManager).updateDeviceLastSeen(eq(oldAccount), eq(device), anyLong()); + verify(accountsManager).updateDeviceLastSeen(eq(oldAccount.getIdentifier(IdentityType.ACI)), eq(device), anyLong()); assertThat(device.getLastSeen()).isEqualTo(today); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java index cdf16b814..3d1200578 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java @@ -76,10 +76,14 @@ class RegistrationLockVerificationManagerTest { AccountsHelper.setupMockUpdate(accountsManager); account = mock(Account.class); - when(account.getUuid()).thenReturn(UUID.randomUUID()); + final UUID accountIdentifier = UUID.randomUUID(); + when(account.getUuid()).thenReturn(accountIdentifier); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); when(account.getNumber()).thenReturn("+18005551212"); when(account.getDevices()).thenReturn(List.of(device)); + AccountsHelper.setupMockGet(accountsManager, account); + existingRegistrationLock = mock(StoredRegistrationLock.class); when(account.getRegistrationLock()).thenReturn(existingRegistrationLock); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java index 6e3989316..c056203e7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java @@ -55,6 +55,7 @@ import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; import org.whispersystems.textsecuregcm.auth.RedemptionRange; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiterConfig; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -62,6 +63,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; @@ -110,14 +112,10 @@ public class BackupAuthManagerTest { final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); - when(accountsManager.update(any(), any())) - .thenAnswer(invocation -> { - final Account a = invocation.getArgument(0); - final Consumer updater = invocation.getArgument(1); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(aci); - updater.accept(a); - return a; - }); + AccountsHelper.setupMockGet(accountsManager, account); + AccountsHelper.setupMockUpdate(accountsManager); final BackupAuthCredentialRequest messagesCredentialRequest = backupAuthTestUtil.getRequest(messagesBackupKey, aci); final BackupAuthCredentialRequest mediaCredentialRequest = backupAuthTestUtil.getRequest(mediaBackupKey, aci); @@ -135,7 +133,7 @@ public class BackupAuthManagerTest { void commitOnAnyBackupLevel(final BackupLevel backupLevel) { final BackupAuthManager authManager = create(); final Account account = new MockAccountBuilder().backupLevel(backupLevel).build(); - when(accountsManager.update(any(), any())).thenReturn(account); + when(accountsManager.update(any(Account.class), any())).thenReturn(account); final ThrowableAssert.ThrowingCallable commit = () -> authManager.commitBackupId(account, @@ -149,7 +147,7 @@ public class BackupAuthManagerTest { void commitRequiresPrimary() { final BackupAuthManager authManager = create(); final Account account = new MockAccountBuilder().build(); - when(accountsManager.update(any(), any())).thenReturn(account); + when(accountsManager.update(any(Account.class), any())).thenReturn(account); final ThrowableAssert.ThrowingCallable commit = () -> authManager.commitBackupId(account, @@ -306,7 +304,7 @@ public class BackupAuthManagerTest { .backupVoucher(null) .build(); - when(accountsManager.update(any(), any())).thenReturn(updated); + when(accountsManager.update(any(Account.class), any())).thenReturn(updated); clock.pin(day2.plus(Duration.ofSeconds(1))); assertThat(authManager @@ -316,7 +314,7 @@ public class BackupAuthManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> accountUpdater = ArgumentCaptor.forClass( Consumer.class); - verify(accountsManager, times(1)).update(any(), accountUpdater.capture()); + verify(accountsManager, times(1)).update(any(Account.class), accountUpdater.capture()); // If the account is not expired when we go to update it, we shouldn't wipe it out final Account alreadyUpdated = mock(Account.class); @@ -341,11 +339,11 @@ public class BackupAuthManagerTest { .mediaCredential(Optional.of(new byte[0])) .build(); clock.pin(Instant.EPOCH.plus(Duration.ofDays(1))); - when(accountsManager.update(any(), any())).thenReturn(account); + when(accountsManager.update(any(Account.class), any())).thenReturn(account); when(redeemedReceiptsManager.put(any(), eq(expirationTime.getEpochSecond()), eq(201L), eq(aci))) .thenReturn(CompletableFuture.completedFuture(true)); authManager.redeemReceipt(account, receiptPresentation(201, expirationTime)); - verify(accountsManager, times(1)).update(any(), any()); + verify(accountsManager, times(1)).update(any(Account.class), any()); } @Test @@ -375,13 +373,13 @@ public class BackupAuthManagerTest { .build(); clock.pin(Instant.EPOCH.plus(Duration.ofDays(1))); - when(accountsManager.update(any(), any())).thenReturn(account); + when(accountsManager.update(any(Account.class), any())).thenReturn(account); when(redeemedReceiptsManager.put(any(), eq(newExpirationTime.getEpochSecond()), eq(201L), eq(aci))) .thenReturn(CompletableFuture.completedFuture(true)); authManager.redeemReceipt(account, receiptPresentation(201, newExpirationTime)); final ArgumentCaptor> updaterCaptor = ArgumentCaptor.captor(); - verify(accountsManager, times(1)).update(any(), updaterCaptor.capture()); + verify(accountsManager, times(1)).update(any(Account.class), updaterCaptor.capture()); updaterCaptor.getValue().accept(account); // Should select the voucher with the later expiration time @@ -430,7 +428,7 @@ public class BackupAuthManagerTest { .build(); clock.pin(Instant.EPOCH.plus(Duration.ofDays(1))); - when(accountsManager.update(any(), any())).thenReturn(account); + when(accountsManager.update(any(Account.class), any())).thenReturn(account); when(redeemedReceiptsManager.put(any(), eq(expirationTime.getEpochSecond()), eq(201L), eq(aci))) .thenReturn(CompletableFuture.completedFuture(false)); @@ -510,7 +508,7 @@ public class BackupAuthManagerTest { .backupVoucher(backupVoucher) .build(); - when(accountsManager.update(any(), any())).thenReturn(account); + when(accountsManager.update(any(Account.class), any())).thenReturn(account); final Optional newMessagesCredential = switch (messageChange) { case MATCH -> Optional.of(storedMessagesCredential); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 592d315f3..b7f72f7d2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -293,7 +293,7 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(204); verify(AuthHelper.VALID_DEVICE_3_PRIMARY, times(1)).setGcmId(eq("z000")); - verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), anyByte(), any()); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.VALID_UUID_3), anyByte(), any()); } } @@ -322,7 +322,7 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(204); verify(AuthHelper.VALID_DEVICE_3_PRIMARY, times(1)).setApnId(eq("first")); - verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), anyByte(), any()); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.VALID_UUID_3), anyByte(), any()); } } @@ -406,7 +406,7 @@ class AccountControllerTest { } // make sure `update()` works - doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(), any()); + doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(UUID.class), any()); try (final Response response = builder.put(Entity.json(new EncryptedUsername(TestRandomUtil.nextBytes(payloadSize))))) { @@ -449,7 +449,7 @@ class AccountControllerTest { } // make sure `update()` works - doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(), any()); + doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(UUID.class), any()); try (final Response delete = builder.delete()) { assertEquals(expectedStatus, delete.getStatus()); @@ -753,7 +753,7 @@ class AccountControllerTest { @Test void testDeleteUsername() { when(accountsManager.clearUsernameHash(any())) - .thenAnswer(invocation -> invocation.getArgument(0)); + .thenAnswer(invocation -> accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow()); try (final Response response = resources.getJerseyTest() .target("/v1/accounts/username_hash/") @@ -762,7 +762,7 @@ class AccountControllerTest { .delete()) { assertThat(response.getStatus()).isEqualTo(204); - verify(accountsManager).clearUsernameHash(AuthHelper.VALID_ACCOUNT); + verify(accountsManager).clearUsernameHash(AuthHelper.VALID_UUID); } } @@ -887,7 +887,7 @@ class AccountControllerTest { .delete()) { assertThat(response.getStatus()).isEqualTo(204); - verify(accountsManager).delete(AuthHelper.VALID_ACCOUNT, AccountsManager.DeletionReason.USER_REQUEST); + verify(accountsManager).delete(AuthHelper.VALID_UUID, AccountsManager.DeletionReason.USER_REQUEST); } } @@ -903,7 +903,7 @@ class AccountControllerTest { .delete()) { assertThat(response.getStatus()).isEqualTo(500); - verify(accountsManager).delete(AuthHelper.VALID_ACCOUNT, AccountsManager.DeletionReason.USER_REQUEST); + verify(accountsManager).delete(AuthHelper.VALID_UUID, AccountsManager.DeletionReason.USER_REQUEST); } } @@ -1074,7 +1074,7 @@ class AccountControllerTest { .put(Entity.json(new DeviceName(TestRandomUtil.nextBytes(64))))) { assertThat(response.getStatus()).isEqualTo(204); - verify(accountsManager).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), eq(Device.PRIMARY_ID), any()); + verify(accountsManager).updateDevice(eq(AuthHelper.VALID_UUID_3), eq(Device.PRIMARY_ID), any()); } } @@ -1088,7 +1088,7 @@ class AccountControllerTest { .put(Entity.json(new DeviceName(TestRandomUtil.nextBytes(64))))) { assertThat(response.getStatus()).isEqualTo(204); - verify(accountsManager).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), eq(AuthHelper.VALID_DEVICE_3_LINKED_ID), any()); + verify(accountsManager).updateDevice(eq(AuthHelper.VALID_UUID_3), eq(AuthHelper.VALID_DEVICE_3_LINKED_ID), any()); } } @@ -1102,7 +1102,7 @@ class AccountControllerTest { .put(Entity.json(new DeviceName(TestRandomUtil.nextBytes(64))))) { assertThat(response.getStatus()).isEqualTo(204); - verify(accountsManager).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), eq(AuthHelper.VALID_DEVICE_3_LINKED_ID), any()); + verify(accountsManager).updateDevice(eq(AuthHelper.VALID_UUID_3), eq(AuthHelper.VALID_DEVICE_3_LINKED_ID), any()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 765e11b4e..4bcb3ade8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -30,8 +30,6 @@ import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import java.nio.charset.StandardCharsets; -import java.time.Instant; -import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Base64; import java.util.EnumSet; @@ -141,7 +139,9 @@ class DeviceControllerTest { when(account.getNextDeviceId()).thenReturn(NEXT_DEVICE_ID); when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID); when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI); + when(account.getIdentifier(IdentityType.PNI)).thenReturn(AuthHelper.VALID_PNI); when(account.getPrimaryDevice()).thenReturn(primaryDevice); when(account.getDevice(anyByte())).thenReturn(Optional.empty()); when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice)); @@ -244,7 +244,7 @@ class DeviceControllerTest { when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID)); when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> { - final Account a = invocation.getArgument(0); + final Account a = accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow(); final DeviceSpec deviceSpec = invocation.getArgument(1); return new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey)); @@ -268,7 +268,7 @@ class DeviceControllerTest { assertThat(response.deviceId()).isEqualTo(NEXT_DEVICE_ID); final ArgumentCaptor deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class); - verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture(), any()); + verify(accountsManager).addDevice(eq(AuthHelper.VALID_UUID), deviceSpecCaptor.capture(), any()); final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey); @@ -796,7 +796,7 @@ class DeviceControllerTest { when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> { - final Account a = invocation.getArgument(0); + final Account a = accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow(); final DeviceSpec deviceSpec = invocation.getArgument(1); return new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey)); @@ -886,7 +886,7 @@ class DeviceControllerTest { @Test void maxDevicesTest() throws LinkDeviceTokenAlreadyUsedException { final List devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1) - .mapToObj(i -> mock(Device.class)) + .mapToObj(_ -> mock(Device.class)) .toList(); when(account.getDevices()).thenReturn(devices); @@ -939,7 +939,7 @@ class DeviceControllerTest { final byte deviceId = 2; - when(accountsManager.removeDevice(account, deviceId)) + when(accountsManager.removeDevice(AuthHelper.VALID_UUID, deviceId)) .thenReturn(account); try (final Response response = resources @@ -953,7 +953,7 @@ class DeviceControllerTest { assertThat(response.getStatus()).isEqualTo(204); assertThat(response.hasEntity()).isFalse(); - verify(accountsManager).removeDevice(account, deviceId); + verify(accountsManager).removeDevice(AuthHelper.VALID_UUID, deviceId); } } @@ -980,7 +980,7 @@ class DeviceControllerTest { void removeDeviceBySelf() { final byte deviceId = 2; - when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT_3, deviceId)) + when(accountsManager.removeDevice(AuthHelper.VALID_UUID_3, deviceId)) .thenReturn(account); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)) @@ -997,7 +997,7 @@ class DeviceControllerTest { assertThat(response.getStatus()).isEqualTo(204); assertThat(response.hasEntity()).isFalse(); - verify(accountsManager).removeDevice(AuthHelper.VALID_ACCOUNT_3, deviceId); + verify(accountsManager).removeDevice(AuthHelper.VALID_UUID_3, deviceId); } } @@ -1173,7 +1173,6 @@ class DeviceControllerTest { void recordTransferArchiveFailed() { final byte deviceId = Device.PRIMARY_ID + 1; final int registrationId = 123; - final Instant deviceCreated = Instant.now().truncatedTo(ChronoUnit.MILLIS); final RemoteAttachmentError transferFailure = new RemoteAttachmentError(RemoteAttachmentError.ErrorType.CONTINUE_WITHOUT_UPLOAD); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java index c98b694ab..84b36a349 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java @@ -197,6 +197,7 @@ class ProfileControllerTest { when(profileAccount.getIdentityKey(IdentityType.ACI)).thenReturn(ACCOUNT_TWO_IDENTITY_KEY); when(profileAccount.getIdentityKey(IdentityType.PNI)).thenReturn(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY); when(profileAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID_TWO); + when(profileAccount.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID_TWO); when(profileAccount.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI_TWO); when(profileAccount.getCurrentProfileVersion()).thenReturn(Optional.empty()); when(profileAccount.getUsernameHash()).thenReturn(Optional.of(USERNAME_HASH)); @@ -207,6 +208,7 @@ class ProfileControllerTest { capabilitiesAccount = mock(Account.class); when(capabilitiesAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID); + when(capabilitiesAccount.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID); when(capabilitiesAccount.getIdentityKey(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTITY_KEY); when(capabilitiesAccount.getIdentityKey(IdentityType.PNI)).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY); @@ -1010,6 +1012,7 @@ class ProfileControllerTest { void testGetProfileWithExpiringProfileKeyCredentialVersionNotFound() throws VerificationFailedException { final Account account = mock(Account.class); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID); when(account.getCurrentProfileVersion()).thenReturn(Optional.of(versionHex("version"))); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); @@ -1218,6 +1221,7 @@ class ProfileControllerTest { final Account account = mock(Account.class); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID); when(account.getCurrentProfileVersion()).thenReturn(Optional.of(version)); when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); when(account.isIdentifiedBy(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(true); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index ebf26a7f5..ab73aeef5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -134,7 +134,7 @@ class RegistrationControllerTest { void setUp() { when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); - when(accountsManager.update(any(), any())).thenAnswer(invocation -> { + when(accountsManager.update(any(UUID.class), any())).thenAnswer(invocation -> { final Account account = invocation.getArgument(0); final Consumer accountUpdater = invocation.getArgument(1); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcServiceTest.java index ed28e6603..799e97618 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcServiceTest.java @@ -27,7 +27,6 @@ import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -77,6 +76,7 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.UsernameHashNotAvailableException; import org.whispersystems.textsecuregcm.storage.UsernameReservationNotFoundException; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier; @@ -97,15 +97,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest { - final Account account = invocation.getArgument(0); - final Consumer updater = invocation.getArgument(1); - - updater.accept(account); - - return account; - }); + AccountsHelper.setupMockUpdate(accountsManager); final RateLimiters rateLimiters = mock(RateLimiters.class); when(rateLimiters.getUsernameReserveLimiter()).thenReturn(rateLimiter); @@ -154,15 +146,10 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().setRegistrationLock(SetRegistrationLockRequest.newBuilder() .build())); - verify(accountsManager, never()).update(any(), any()); + verify(accountsManager, never()).update(any(UUID.class), any()); } @Test @@ -219,7 +206,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().clearRegistrationLock(ClearRegistrationLockRequest.newBuilder().build())); - verify(accountsManager, never()).update(any(), any()); + verify(accountsManager, never()).update(any(UUID.class), any()); } @Test void reserveUsernameHash() throws UsernameHashNotAvailableException { final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI); when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)) .thenReturn(Optional.of(account)); @@ -259,7 +247,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest { final List usernameHashes = invocation.getArgument(1); - return new AccountsManager.UsernameReservation(invocation.getArgument(0), usernameHashes.getFirst()); + return new AccountsManager.UsernameReservation(accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow(), usernameHashes.getFirst()); }); final ReserveUsernameHashResponse expectedResponse = ReserveUsernameHashResponse.newBuilder() @@ -363,7 +351,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest { final Account updatedAccount = mock(Account.class); @@ -502,12 +490,12 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().deleteUsernameHash(DeleteUsernameHashRequest.newBuilder().build())); - verify(accountsManager).clearUsernameHash(account); + verify(accountsManager).clearUsernameHash(AUTHENTICATED_ACI); } @ParameterizedTest @@ -515,6 +503,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest { @@ -66,6 +66,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest { - final Account account = invocation.getArgument(0); - final Consumer updater = invocation.getArgument(1); - - updater.accept(account); - - return account; - }); - - when(accountsManager.updateDevice(any(), anyByte(), any())) - .thenAnswer(invocation -> { - final Account account = invocation.getArgument(0); - final Device device = account.getDevice(invocation.getArgument(1)).orElseThrow(); - final Consumer updater = invocation.getArgument(2); - - updater.accept(device); - - return account; - }); + AccountsHelper.setupMockUpdate(accountsManager); return new DevicesGrpcService(accountsManager); } @@ -154,7 +136,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest + disconnectedAccount.getIdentifier(IdentityType.ACI).equals(account.getIdentifier(IdentityType.ACI)))); } @SuppressWarnings("OptionalUsedAsFieldOrParameterType") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index f1d9037b7..74ee4ee29 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -169,7 +169,7 @@ class AccountsManagerChangeNumberIntegrationTest { final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate(); - accountsManager.changeNumber(account, + accountsManager.changeNumber(originalUuid, secondNumber, new IdentityKey(pniIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), @@ -197,7 +197,7 @@ class AccountsManagerChangeNumberIntegrationTest { final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate(); - accountsManager.changeNumber(account, + accountsManager.changeNumber(originalUuid, originalNumber, new IdentityKey(pniIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), @@ -235,7 +235,7 @@ class AccountsManagerChangeNumberIntegrationTest { final Map kemSignedPreKeys = Map.of(Device.PRIMARY_ID, rotatedKemSignedPreKey); final Map registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId); - final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, kemSignedPreKeys, registrationIds); + final Account updatedAccount = accountsManager.changeNumber(originalUuid, secondNumber, pniIdentityKey, preKeys, kemSignedPreKeys, registrationIds); final UUID secondPni = updatedAccount.getPhoneNumberIdentifier(); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -270,7 +270,7 @@ class AccountsManagerChangeNumberIntegrationTest { final ECKeyPair originalIdentityKeyPair = ECKeyPair.generate(); final ECKeyPair secondIdentityKeyPair = ECKeyPair.generate(); - account = accountsManager.changeNumber(account, + account = accountsManager.changeNumber(originalUuid, secondNumber, new IdentityKey(secondIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, secondIdentityKeyPair)), @@ -279,7 +279,7 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID secondPni = account.getPhoneNumberIdentifier(); - accountsManager.changeNumber(account, + accountsManager.changeNumber(originalUuid, originalNumber, new IdentityKey(originalIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(3, originalIdentityKeyPair)), @@ -315,7 +315,7 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID existingAccountUuid = existingAccount.getUuid(); - accountsManager.changeNumber(account, + accountsManager.changeNumber(originalUuid, secondNumber, new IdentityKey(secondIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, secondIdentityKeyPair)), @@ -337,7 +337,7 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); - accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), + accountsManager.changeNumber(originalUuid, originalNumber, new IdentityKey(originalIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, originalIdentityKeyPair)), @@ -364,7 +364,7 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID existingAccountUuid = existingAccount.getUuid(); final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate(); - final Account changedNumberAccount = accountsManager.changeNumber(account, + final Account changedNumberAccount = accountsManager.changeNumber(originalUuid, secondNumber, new IdentityKey(pniIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), @@ -383,7 +383,7 @@ class AccountsManagerChangeNumberIntegrationTest { final ECKeyPair reRegisteredPniIdentityKeyPair = ECKeyPair.generate(); - final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, + final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount.getIdentifier(IdentityType.ACI), secondNumber, new IdentityKey(reRegisteredPniIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, reRegisteredPniIdentityKeyPair)), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index 70b0c7328..e479dfcc2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -159,7 +159,7 @@ class AccountsManagerConcurrentModificationIntegrationTest { KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(4, pniKeyPair)), - null), + null).getIdentifier(IdentityType.ACI), a -> { a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); a.removeDevice(Device.PRIMARY_ID); @@ -226,18 +226,10 @@ class AccountsManagerConcurrentModificationIntegrationTest { } private CompletableFuture modifyAccount(final UUID uuid, final Consumer accountMutation) { - - return CompletableFuture.runAsync(() -> { - final Account account = accountsManager.getByAccountIdentifier(uuid).orElseThrow(); - accountsManager.update(account, accountMutation); - }, mutationExecutor); + return CompletableFuture.runAsync(() -> accountsManager.update(uuid, accountMutation), mutationExecutor); } private CompletableFuture modifyDevice(final UUID uuid, final byte deviceId, final Consumer deviceMutation) { - - return CompletableFuture.runAsync(() -> { - final Account account = accountsManager.getByAccountIdentifier(uuid).orElseThrow(); - accountsManager.updateDevice(account, deviceId, deviceMutation); - }, mutationExecutor); + return CompletableFuture.runAsync(() -> accountsManager.updateDevice(uuid, deviceId, deviceMutation), mutationExecutor); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index ffb417dd4..42a4d5900 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -603,6 +603,7 @@ class AccountsManagerTest { UUID uuid = UUID.randomUUID(); UUID pni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null); @@ -614,40 +615,21 @@ class AccountsManagerTest { final IdentityKey identityKey = new IdentityKey(ECKeyPair.generate().getPublicKey()); - account = accountsManager.update(account, a -> a.setIdentityKey(identityKey)); + account = accountsManager.update(uuid, a -> a.setIdentityKey(identityKey)); assertEquals(1, account.getVersion()); assertEquals(identityKey, account.getIdentityKey(IdentityType.ACI)); - verify(accounts, times(1)).getByAccountIdentifier(uuid); + verify(accounts, times(2)).getByAccountIdentifier(uuid); verify(accounts, times(2)).update(any()); verifyNoMoreInteractions(accounts); } - @Test - void testUpdate_dynamoOptimisticLockingFailureDuringCreate() throws AccountAlreadyExistsException { - UUID uuid = UUID.randomUUID(); - Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - - when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null); - when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty()) - .thenReturn(Optional.of(account)); - when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class); - - accountsManager.update(account, a -> { - }); - - verify(accounts, times(1)).update(account); - verifyNoMoreInteractions(accounts); - } - @Test void testUpdateDevice() { final UUID uuid = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - - when(accounts.getByAccountIdentifier(uuid)).thenReturn( - Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]))); + addRetrievableAccount(account); assertTrue(account.getDevices().isEmpty()); @@ -661,14 +643,14 @@ class AccountsManagerTest { @SuppressWarnings("unchecked") Consumer deviceUpdater = mock(Consumer.class); @SuppressWarnings("unchecked") Consumer unknownDeviceUpdater = mock(Consumer.class); - account = accountsManager.updateDevice(account, deviceId, deviceUpdater); - account = accountsManager.updateDevice(account, deviceId, d -> d.setName("deviceName".getBytes(StandardCharsets.UTF_8))); + account = accountsManager.updateDevice(uuid, deviceId, deviceUpdater); + account = accountsManager.updateDevice(uuid, deviceId, d -> d.setName("deviceName".getBytes(StandardCharsets.UTF_8))); assertArrayEquals("deviceName".getBytes(StandardCharsets.UTF_8), account.getDevice(deviceId).orElseThrow().getName()); verify(deviceUpdater, times(1)).accept(any(Device.class)); - accountsManager.updateDevice(account, account.getNextDeviceId(), unknownDeviceUpdater); + accountsManager.updateDevice(uuid, account.getNextDeviceId(), unknownDeviceUpdater); verify(unknownDeviceUpdater, never()).accept(any(Device.class)); } @@ -689,7 +671,7 @@ class AccountsManagerTest { assertTrue(account.getDevice(linkedDevice.getId()).isPresent()); - account = accountsManager.removeDevice(account, linkedDevice.getId()); + account = accountsManager.removeDevice(account.getIdentifier(IdentityType.ACI), linkedDevice.getId()); assertFalse(account.getDevice(linkedDevice.getId()).isPresent()); verify(messagesManager, times(2)).clear(account.getUuid(), linkedDevice.getId()); @@ -708,7 +690,8 @@ class AccountsManagerTest { when(keysManager.deleteSingleUsePreKeys(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); - assertThrows(IllegalArgumentException.class, () -> accountsManager.removeDevice(account, Device.PRIMARY_ID)); + assertThrows(IllegalArgumentException.class, + () -> accountsManager.removeDevice(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID)); assertDoesNotThrow(account::getPrimaryDevice); verify(messagesManager, never()).clear(any(), anyByte()); @@ -876,7 +859,7 @@ class AccountsManagerTest { CLOCK.pin(CLOCK.instant().plusSeconds(60)); - final Pair updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec( + final Pair updatedAccountAndDevice = accountsManager.addDevice(aci, new DeviceSpec( deviceNameCiphertext, password, signalAgent, @@ -922,10 +905,12 @@ class AccountsManagerTest { @MethodSource void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) { final Account account = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); + final Device device = generateTestDevice(initialLastSeen); account.addDevice(device); - accountsManager.updateDeviceLastSeen(account, device, updatedLastSeen); + accountsManager.updateDeviceLastSeen(account.getIdentifier(IdentityType.ACI), device, updatedLastSeen); assertEquals(expectUpdate ? updatedLastSeen : initialLastSeen, device.getLastSeen()); verify(accounts, expectUpdate ? times(1) : never()).update(account); @@ -957,7 +942,8 @@ class AccountsManagerTest { final KEMSignedPreKey kemLastResortPreKey = KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - account = accountsManager.changeNumber(account, + addRetrievableAccount(account); + account = accountsManager.changeNumber(uuid, targetNumber, new IdentityKey(pniIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, ecSignedPreKey), @@ -985,9 +971,11 @@ class AccountsManagerTest { Account account = AccountsHelper.generateTestAccount(originalNumber, UUID.randomUUID(), phoneNumberIdentifier, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); + phoneNumberIdentifiersByE164.put(originalNumber, account.getPhoneNumberIdentifier()); phoneNumberIdentifiersByE164.put(newNumber, account.getPhoneNumberIdentifier()); - account = accountsManager.changeNumber(account, + account = accountsManager.changeNumber(account.getIdentifier(IdentityType.ACI), newNumber, new IdentityKey(pniIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), @@ -1016,7 +1004,9 @@ class AccountsManagerTest { final KEMSignedPreKey kemLastResoryPreKey = KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - account = accountsManager.changeNumber(account, + addRetrievableAccount(account); + + account = accountsManager.changeNumber(uuid, targetNumber, new IdentityKey(pniIdentityKeyPair.getPublicKey()), Map.of(Device.PRIMARY_ID, ecSignedPreKey), @@ -1064,8 +1054,9 @@ class AccountsManagerTest { DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), DevicesHelper.createDevice(deviceId2, 0L, 102)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); final Account updatedAccount = accountsManager.changeNumber( - account, targetNumber, new IdentityKey(ECKeyPair.generate().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds); + uuid, targetNumber, new IdentityKey(ECKeyPair.generate().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds); assertEquals(targetNumber, updatedAccount.getNumber()); @@ -1102,11 +1093,12 @@ class AccountsManagerTest { final List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), DevicesHelper.createDevice(deviceId2, 0L, 102)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); assertThrows(MismatchedDevicesException.class, () -> accountsManager.changeNumber( - account, targetNumber, new IdentityKey(ECKeyPair.generate().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds)); + uuid, targetNumber, new IdentityKey(ECKeyPair.generate().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds)); - verifyNoInteractions(accounts); + verify(accounts, never()).changeNumber(any(), any(), any(), any(), any()); verifyNoInteractions(keysManager); } @@ -1117,8 +1109,9 @@ class AccountsManagerTest { final UUID uuid = UUID.randomUUID(); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); - assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setNumber(targetNumber, UUID.randomUUID()))); + assertThrows(AssertionError.class, () -> accountsManager.update(uuid, a -> a.setNumber(targetNumber, UUID.randomUUID()))); } @Test @@ -1128,7 +1121,7 @@ class AccountsManagerTest { final List usernameHashes = List.of(TestRandomUtil.nextBytes(32), TestRandomUtil.nextBytes(32)); - final UsernameReservation result = accountsManager.reserveUsernameHash(account, usernameHashes); + final UsernameReservation result = accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes); assertArrayEquals(usernameHashes.getFirst(), result.reservedUsernameHash()); verify(accounts, times(1)).reserveUsernameHash(eq(account), any(), eq(Duration.ofMinutes(5))); } @@ -1142,7 +1135,7 @@ class AccountsManagerTest { final List usernameHashes = List.of(TestRandomUtil.nextBytes(32), oldUsernameHash, TestRandomUtil.nextBytes(32)); - final UsernameReservation result = accountsManager.reserveUsernameHash(account, usernameHashes); + final UsernameReservation result = accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes); assertArrayEquals(oldUsernameHash, result.reservedUsernameHash()); verify(accounts, never()).reserveUsernameHash(any(), any(), any()); } @@ -1158,7 +1151,7 @@ class AccountsManagerTest { .doNothing() .when(accounts).reserveUsernameHash(any(), any(), any()); - final UsernameReservation result = accountsManager.reserveUsernameHash(account, usernameHashes); + final UsernameReservation result = accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes); assertArrayEquals(usernameHashes.getFirst(), result.reservedUsernameHash()); verify(accounts, times(2)).reserveUsernameHash(eq(account), any(), eq(Duration.ofMinutes(5))); } @@ -1166,21 +1159,23 @@ class AccountsManagerTest { @Test void testReserveUsernameHashAsyncNotAvailable() throws UsernameHashNotAvailableException { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(accounts.getByAccountIdentifierAsync(account.getUuid())).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + addRetrievableAccount(account); doThrow(new UsernameHashNotAvailableException()) .when(accounts).reserveUsernameHash(any(), any(), any()); assertThrows(UsernameHashNotAvailableException.class, () -> - accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1, USERNAME_HASH_2))); + accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1, USERNAME_HASH_2))); } @Test void testConfirmReservedUsernameHash() throws UsernameHashNotAvailableException, UsernameReservationNotFoundException { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); + setReservationHash(account, USERNAME_HASH_1); - accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1); + accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); verify(accounts).confirmUsernameHash(eq(account), eq(USERNAME_HASH_1), eq(ENCRYPTED_USERNAME_1)); } @@ -1194,62 +1189,72 @@ class AccountsManagerTest { .doNothing() .when(accounts).confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1); - accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1); + accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); verify(accounts, times(2)).confirmUsernameHash(eq(account), eq(USERNAME_HASH_1), eq(ENCRYPTED_USERNAME_1)); } @Test void testConfirmReservedHashNameMismatch() { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); + setReservationHash(account, USERNAME_HASH_1); assertThrows(UsernameReservationNotFoundException.class, - () -> accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_2, ENCRYPTED_USERNAME_2)); + () -> accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_2, ENCRYPTED_USERNAME_2)); } @Test void testConfirmReservedLapsed() throws UsernameHashNotAvailableException { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); // hash was reserved, but the reservation lapsed and another account took it setReservationHash(account, USERNAME_HASH_1); doThrow(new UsernameHashNotAvailableException()) .when(accounts).confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1); assertThrows(UsernameHashNotAvailableException.class, - () -> accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1)); + () -> accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1)); assertTrue(account.getUsernameHash().isEmpty()); } @Test void testConfirmReservedRetry() throws UsernameHashNotAvailableException, UsernameReservationNotFoundException { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); account.setUsernameHash(USERNAME_HASH_1); // reserved username already set, should be treated as a replay - accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1); - verifyNoInteractions(accounts); + accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); + verify(accounts, never()).confirmUsernameHash(any(), any(), any()); } @Test void testConfirmReservedUsernameHashWithNoReservation() throws UsernameHashNotAvailableException { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); + assertThrows(UsernameReservationNotFoundException.class, - () -> accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1)); + () -> accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1)); verify(accounts, never()).confirmUsernameHash(any(), any(), any()); } @Test void testClearUsernameHash() { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); + account.setUsernameHash(USERNAME_HASH_1); - accountsManager.clearUsernameHash(account); + accountsManager.clearUsernameHash(account.getIdentifier(IdentityType.ACI)); verify(accounts).clearUsernameHash(eq(account)); } @Test void testSetUsernameViaUpdate() { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + addRetrievableAccount(account); - assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setUsernameHash(USERNAME_HASH_1))); + assertThrows(AssertionError.class, () -> + accountsManager.update(account.getIdentifier(IdentityType.ACI), a -> a.setUsernameHash(USERNAME_HASH_1))); } @Test @@ -1440,4 +1445,12 @@ class AccountsManagerTest { new MismatchedDevices(Set.of(deviceId), Set.of((byte) (extraDeviceId)), Collections.emptySet()))) ); } + + private void addRetrievableAccount(final Account account) { + when(accounts.getByAccountIdentifier(account.getIdentifier(IdentityType.ACI))) + .thenReturn(Optional.of(account)); + + when(accounts.getByAccountIdentifierAsync(account.getIdentifier(IdentityType.ACI))) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java index 199de8bbf..846a092de 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.Mockito; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -176,7 +177,7 @@ class AccountsManagerUsernameIntegrationTest { } assertThrows(UsernameHashNotAvailableException.class, - () -> accountsManager.reserveUsernameHash(account, usernameHashes)); + () -> accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes)); assertThat(accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getUsernameHash()).isEmpty(); } @@ -201,7 +202,7 @@ class AccountsManagerUsernameIntegrationTest { usernameHashes.add(TestRandomUtil.nextBytes(32)); final byte[] username = accountsManager - .reserveUsernameHash(account, usernameHashes) + .reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes) .reservedUsernameHash(); assertArrayEquals(username, availableHash); @@ -214,14 +215,14 @@ class AccountsManagerUsernameIntegrationTest { // reserve AccountsManager.UsernameReservation reservation = - accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)); + accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1)); assertArrayEquals(reservation.account().getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1); assertThat(accountsManager.getByUsernameHash(reservation.reservedUsernameHash()).join()).isEmpty(); // confirm account = accountsManager.confirmReservedUsernameHash( - reservation.account(), + reservation.account().getIdentifier(IdentityType.ACI), reservation.reservedUsernameHash(), ENCRYPTED_USERNAME_1); assertArrayEquals(account.getUsernameHash().orElseThrow(), USERNAME_HASH_1); @@ -232,7 +233,7 @@ class AccountsManagerUsernameIntegrationTest { .isEqualTo(account.getUuid()); // clear - account = accountsManager.clearUsernameHash(account); + account = accountsManager.clearUsernameHash(account.getIdentifier(IdentityType.ACI)); assertThat(accountsManager.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); assertThat(accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getUsernameHash()).isEmpty(); } @@ -243,16 +244,16 @@ class AccountsManagerUsernameIntegrationTest { Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); AccountsManager.UsernameReservation reservation = - accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)); + accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1)); // confirm account = accountsManager.confirmReservedUsernameHash( - reservation.account(), + reservation.account().getIdentifier(IdentityType.ACI), reservation.reservedUsernameHash(), ENCRYPTED_USERNAME_1); // clear - account = accountsManager.clearUsernameHash(account); + account = accountsManager.clearUsernameHash(account.getIdentifier(IdentityType.ACI)); assertThat(accountsManager.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); assertThat(accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getUsernameHash()).isEmpty(); @@ -260,7 +261,7 @@ class AccountsManagerUsernameIntegrationTest { Account account2 = AccountsHelper.createAccount(accountsManager, "+18005552222"); assertThrows(UsernameHashNotAvailableException.class, - () -> accountsManager.reserveUsernameHash(account2, List.of(USERNAME_HASH_1)), + () -> accountsManager.reserveUsernameHash(account2.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1)), "account2 should not be able to reserve a held hash"); } @@ -270,7 +271,7 @@ class AccountsManagerUsernameIntegrationTest { final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); AccountsManager.UsernameReservation reservation1 = - accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)); + accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1)); long past = Instant.now().minus(Duration.ofMinutes(1)).getEpochSecond(); // force expiration @@ -286,12 +287,12 @@ class AccountsManagerUsernameIntegrationTest { Account account2 = AccountsHelper.createAccount(accountsManager, "+18005552222"); final AccountsManager.UsernameReservation reservation2 = - accountsManager.reserveUsernameHash(account2, List.of(USERNAME_HASH_1)); + accountsManager.reserveUsernameHash(account2.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1)); assertArrayEquals(reservation2.reservedUsernameHash(), USERNAME_HASH_1); assertThrows(UsernameHashNotAvailableException.class, - () -> accountsManager.confirmReservedUsernameHash(reservation1.account(), USERNAME_HASH_1, ENCRYPTED_USERNAME_1)); - account2 = accountsManager.confirmReservedUsernameHash(reservation2.account(), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); + () -> accountsManager.confirmReservedUsernameHash(reservation1.account().getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1)); + account2 = accountsManager.confirmReservedUsernameHash(reservation2.account().getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); assertEquals(accountsManager.getByUsernameHash(USERNAME_HASH_1).join().orElseThrow().getUuid(), account2.getUuid()); assertArrayEquals(account2.getUsernameHash().orElseThrow(), USERNAME_HASH_1); } @@ -303,13 +304,13 @@ class AccountsManagerUsernameIntegrationTest { // Set username hash final AccountsManager.UsernameReservation reservation1 = - accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)); + accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1)); - account = accountsManager.confirmReservedUsernameHash(reservation1.account(), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); + account = accountsManager.confirmReservedUsernameHash(reservation1.account().getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); // Reserve another hash on the same account final AccountsManager.UsernameReservation reservation2 = - accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_2)); + accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_2)); account = reservation2.account(); @@ -318,12 +319,12 @@ class AccountsManagerUsernameIntegrationTest { assertArrayEquals(account.getEncryptedUsername().orElseThrow(), ENCRYPTED_USERNAME_1); // Clear the set username hash but not the reserved one - account = accountsManager.clearUsernameHash(account); + account = accountsManager.clearUsernameHash(account.getIdentifier(IdentityType.ACI)); assertThat(account.getReservedUsernameHash()).isPresent(); assertThat(account.getUsernameHash()).isEmpty(); // Confirm second reservation - account = accountsManager.confirmReservedUsernameHash(account, reservation2.reservedUsernameHash(), ENCRYPTED_USERNAME_2); + account = accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), reservation2.reservedUsernameHash(), ENCRYPTED_USERNAME_2); assertArrayEquals(account.getUsernameHash().orElseThrow(), USERNAME_HASH_2); assertArrayEquals(account.getEncryptedUsername().orElseThrow(), ENCRYPTED_USERNAME_2); } @@ -333,8 +334,8 @@ class AccountsManagerUsernameIntegrationTest { throws InterruptedException, UsernameHashNotAvailableException, UsernameReservationNotFoundException { Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); final AccountsManager.UsernameReservation reservation1 = - accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)); - account = accountsManager.confirmReservedUsernameHash(reservation1.account(), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); + accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1)); + account = accountsManager.confirmReservedUsernameHash(reservation1.account().getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); // "reclaim" the account by re-registering Account reclaimed = AccountsHelper.createAccount(accountsManager, "+18005551111"); @@ -346,20 +347,29 @@ class AccountsManagerUsernameIntegrationTest { assertThat(accountsManager.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); // confirm it again - accountsManager.confirmReservedUsernameHash(reclaimed, USERNAME_HASH_1, ENCRYPTED_USERNAME_1); + accountsManager.confirmReservedUsernameHash(reclaimed.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1); assertThat(accountsManager.getByUsernameHash(USERNAME_HASH_1).join()).isPresent(); } @Test - public void testUsernameLinks() throws InterruptedException, AccountAlreadyExistsException { - final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); + public void testUsernameLinks() + throws InterruptedException, AccountAlreadyExistsException, UsernameHashNotAvailableException, UsernameReservationNotFoundException { + final UUID accountIdentifier; + { + final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); + accounts.create(account, Collections.emptyList()); - account.setUsernameHash(TestRandomUtil.nextBytes(16)); - accounts.create(account, Collections.emptyList()); + accountIdentifier = account.getIdentifier(IdentityType.ACI); + } + + final AccountsManager.UsernameReservation reservation = + accountsManager.reserveUsernameHash(accountIdentifier, List.of(USERNAME_HASH_1)); + + accountsManager.confirmReservedUsernameHash(accountIdentifier, reservation.reservedUsernameHash(), ENCRYPTED_USERNAME_1); final UUID linkHandle = UUID.randomUUID(); final byte[] encryptedUsername = TestRandomUtil.nextBytes(32); - accountsManager.update(account, a -> a.setUsernameLinkDetails(linkHandle, encryptedUsername)); + accountsManager.update(accountIdentifier, account -> account.setUsernameLinkDetails(linkHandle, encryptedUsername)); final Optional maybeAccount = accountsManager.getByUsernameLinkHandle(linkHandle).join(); assertTrue(maybeAccount.isPresent()); @@ -367,17 +377,17 @@ class AccountsManagerUsernameIntegrationTest { assertArrayEquals(encryptedUsername, maybeAccount.get().getEncryptedUsername().get()); // making some unrelated change and updating account to check that username link data is still there - final Optional accountToChange = accountsManager.getByAccountIdentifier(account.getUuid()); + final Optional accountToChange = accountsManager.getByAccountIdentifier(accountIdentifier); assertTrue(accountToChange.isPresent()); - accountsManager.update(accountToChange.get(), a -> a.setDiscoverableByPhoneNumber(!a.isDiscoverableByPhoneNumber())); + accountsManager.update(accountToChange.get().getIdentifier(IdentityType.ACI), a -> a.setDiscoverableByPhoneNumber(!a.isDiscoverableByPhoneNumber())); final Optional accountAfterChange = accountsManager.getByUsernameLinkHandle(linkHandle).join(); assertTrue(accountAfterChange.isPresent()); assertTrue(accountAfterChange.get().getEncryptedUsername().isPresent()); assertArrayEquals(encryptedUsername, accountAfterChange.get().getEncryptedUsername().get()); // now deleting the link - final Optional accountToDeleteLink = accountsManager.getByAccountIdentifier(account.getUuid()); - accountsManager.update(accountToDeleteLink.orElseThrow(), a -> a.setUsernameLinkDetails(null, null)); + final Optional accountToDeleteLink = accountsManager.getByAccountIdentifier(accountIdentifier); + accountsManager.update(accountToDeleteLink.orElseThrow().getIdentifier(IdentityType.ACI), a -> a.setUsernameLinkDetails(null, null)); assertTrue(accounts.getByUsernameLinkHandle(linkHandle).join().isEmpty()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java index e251a62d1..beda11000 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java @@ -184,7 +184,7 @@ public class AddRemoveDeviceIntegrationTest { assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size()); final Pair updatedAccountAndDevice = - accountsManager.addDevice(account, new DeviceSpec( + accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", @@ -234,7 +234,7 @@ public class AddRemoveDeviceIntegrationTest { final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)); final Pair updatedAccountAndDevice = - accountsManager.addDevice(account, new DeviceSpec( + accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", @@ -255,7 +255,7 @@ public class AddRemoveDeviceIntegrationTest { .size()); assertThrows(LinkDeviceTokenAlreadyUsedException.class, - () -> accountsManager.addDevice(account, new DeviceSpec( + () -> accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", @@ -289,7 +289,7 @@ public class AddRemoveDeviceIntegrationTest { assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size()); final Pair updatedAccountAndDevice = - accountsManager.addDevice(account, new DeviceSpec( + accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", @@ -307,7 +307,7 @@ public class AddRemoveDeviceIntegrationTest { final byte addedDeviceId = updatedAccountAndDevice.second().getId(); - final Account updatedAccount = accountsManager.removeDevice(updatedAccountAndDevice.first(), addedDeviceId); + final Account updatedAccount = accountsManager.removeDevice(updatedAccountAndDevice.first().getIdentifier(IdentityType.ACI), addedDeviceId); assertEquals(1, updatedAccount.getDevices().size()); @@ -340,7 +340,7 @@ public class AddRemoveDeviceIntegrationTest { final UUID aci = account.getIdentifier(IdentityType.ACI); final Pair updatedAccountAndDevice = - accountsManager.addDevice(account, new DeviceSpec( + accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", @@ -362,7 +362,7 @@ public class AddRemoveDeviceIntegrationTest { .thenReturn(CompletableFuture.failedFuture(new RuntimeException("OH NO"))); assertThrows(RuntimeException.class, - () -> accountsManager.removeDevice(updatedAccountAndDevice.first(), addedDeviceId)); + () -> accountsManager.removeDevice(updatedAccountAndDevice.first().getIdentifier(IdentityType.ACI), addedDeviceId)); final Account retrievedAccount = accountsManager.getByAccountIdentifierAsync(aci).join().orElseThrow(); @@ -410,7 +410,7 @@ public class AddRemoveDeviceIntegrationTest { assertEquals(Optional.empty(), displacedFuture.join()); final Pair updatedAccountAndDevice = - accountsManager.addDevice(account, new DeviceSpec( + accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", @@ -451,7 +451,7 @@ public class AddRemoveDeviceIntegrationTest { final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); final Pair updatedAccountAndDevice = - accountsManager.addDevice(account, new DeviceSpec( + accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", @@ -520,7 +520,7 @@ public class AddRemoveDeviceIntegrationTest { final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID()); final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); - accountsManager.addDevice(account, new DeviceSpec( + accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", @@ -563,7 +563,7 @@ public class AddRemoveDeviceIntegrationTest { final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); clock.pin(Instant.ofEpochMilli(0)); - accountsManager.addDevice(account, new DeviceSpec( + accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec( "device-name".getBytes(StandardCharsets.UTF_8), "password", "OWT", diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 6704af13c..a31912be0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.TestClock; @@ -56,7 +57,7 @@ public class ChangeNumberManagerTest { updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); when(accountsManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer)invocation -> { - final Account account = invocation.getArgument(0, Account.class); + final Account account = accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow(); final String number = invocation.getArgument(1, String.class); final UUID uuid = account.getIdentifier(IdentityType.ACI); @@ -99,11 +100,14 @@ public class ChangeNumberManagerTest { final UUID accountIdentifier = UUID.randomUUID(); final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); when(account.isIdentifiedBy(any())).thenReturn(false); when(account.isIdentifiedBy(new AciServiceIdentifier(accountIdentifier))).thenReturn(true); + AccountsHelper.setupMockGet(accountsManager, account); + changeNumberManager.changeNumber(account, targetNumber, pniIdentityKey, ecSignedPreKeys, kemLastResortPreKeys, Collections.emptyList(), Collections.emptyMap(), null); - verify(accountsManager).changeNumber(account, targetNumber, pniIdentityKey, ecSignedPreKeys, kemLastResortPreKeys, Collections.emptyMap()); + verify(accountsManager).changeNumber(accountIdentifier, targetNumber, pniIdentityKey, ecSignedPreKeys, kemLastResortPreKeys, Collections.emptyMap()); verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any()); } @@ -154,6 +158,8 @@ public class ChangeNumberManagerTest { final IncomingMessage incomingMessage = new IncomingMessage(1, linkedDeviceId, linkedDeviceRegistrationId, new byte[] { 1 }); + AccountsHelper.setupMockGet(accountsManager, account); + changeNumberManager.changeNumber(account, targetNumber, pniIdentityKey, @@ -163,7 +169,7 @@ public class ChangeNumberManagerTest { registrationIds, null); - verify(accountsManager).changeNumber(account, + verify(accountsManager).changeNumber(aci, targetNumber, pniIdentityKey, ecSignedPreKeys, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index 029631153..769deaa95 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -31,7 +31,6 @@ import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 4e2b0be4c..d1d2fbf49 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -115,7 +115,7 @@ class MessagePersisterTest { when(accountsManager.getByAccountIdentifierAsync(DESTINATION_ACCOUNT_UUID)) .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); when(accountsManager.removeDevice(any(), anyByte())) - .thenAnswer(invocation -> invocation.getArgument(0)); + .thenAnswer(invocation -> accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow()); when(destinationAccount.getUuid()).thenReturn(DESTINATION_ACCOUNT_UUID); when(destinationAccount.getIdentifier(IdentityType.ACI)).thenReturn(DESTINATION_ACCOUNT_UUID); @@ -465,7 +465,7 @@ class MessagePersisterTest { assertTimeoutPreemptively(Duration.ofSeconds(1), () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()).block()); - verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID); + verify(accountsManager, exactly()).removeDevice(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID); } @Test @@ -552,7 +552,7 @@ class MessagePersisterTest { when(destinationAccount.getDevices()).thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination)); when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); - when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenThrow(new RuntimeException()); + when(accountsManager.removeDevice(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID)).thenThrow(new RuntimeException()); assertThrows(RuntimeException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()).block()); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index 92318c099..43fae0729 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.tests.util; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockingDetails; import static org.mockito.Mockito.when; @@ -18,16 +17,21 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Optional; -import java.util.Set; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; +import org.junit.platform.commons.util.StringUtils; import org.mockito.MockingDetails; import org.mockito.stubbing.Stubbing; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.DeviceAttributes; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -64,8 +68,8 @@ public class AccountsHelper { /** * Sets up stubbing for: *
    - *
  • {@link AccountsManager#update(Account, Consumer)}
  • - *
  • {@link AccountsManager#updateDevice(Account, byte, Consumer)}
  • + *
  • {@link AccountsManager#update(UUID, Consumer)}
  • + *
  • {@link AccountsManager#updateDevice(UUID, byte, Consumer)}
  • *
* * with multiple calls to the {@link Consumer}. This simulates retries from {@link org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException}. @@ -80,67 +84,127 @@ public class AccountsHelper { */ @SuppressWarnings("unchecked") public static void setupMockUpdateWithRetries(final AccountsManager mockAccountsManager, final int retryCount) { - when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> { - final Account account = answer.getArgument(0, Account.class); + when(mockAccountsManager.update(any(UUID.class), any())).thenAnswer(invocation -> { + final UUID accountIdentifier = invocation.getArgument(0, UUID.class); + final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow(); for (int i = 0; i < retryCount; i++) { - answer.getArgument(1, Consumer.class).accept(account); + invocation.getArgument(1, Consumer.class).accept(account); + } + + return account; + }); + + when(mockAccountsManager.update(any(Account.class), any())).thenAnswer(invocation -> { + final Account account = invocation.getArgument(0); + + for (int i = 0; i < retryCount; i++) { + invocation.getArgument(1, Consumer.class).accept(account); } return copyAndMarkStale(account); }); - when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> { - final Account account = answer.getArgument(0, Account.class); + when(mockAccountsManager.updateDevice(any(UUID.class), anyByte(), any())).thenAnswer(answer -> { + final UUID accountIdentifier = answer.getArgument(0, UUID.class); + final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow(); + final byte deviceId = answer.getArgument(1, Byte.class); for (int i = 0; i < retryCount; i++) { account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class)); } - return copyAndMarkStale(account); + return account; }); } @SuppressWarnings("unchecked") private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) { - when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> { - final Account account = answer.getArgument(0, Account.class); - answer.getArgument(1, Consumer.class).accept(account); + when(mockAccountsManager.update(any(UUID.class), any())).thenAnswer(invocation -> { + final UUID accountIdentifier = invocation.getArgument(0, UUID.class); + final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow(); + + invocation.getArgument(1, Consumer.class).accept(account); + + return account; + }); + + when(mockAccountsManager.update(any(Account.class), any())).thenAnswer(invocation -> { + final Account account = invocation.getArgument(0); + + invocation.getArgument(1, Consumer.class).accept(account); return markStale ? copyAndMarkStale(account) : account; }); - when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> { - final Account account = answer.getArgument(0, Account.class); - final byte deviceId = answer.getArgument(1, Byte.class); - account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class)); + when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(invocation -> { + final UUID accountIdentifier = invocation.getArgument(0, UUID.class); + final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow(); - return markStale ? copyAndMarkStale(account) : account; + final byte deviceId = invocation.getArgument(1, Byte.class); + account.getDevice(deviceId).ifPresent(invocation.getArgument(2, Consumer.class)); + + return account; }); - when(mockAccountsManager.updateDeviceLastSeen(any(), any(), anyLong())).thenAnswer(answer -> { - answer.getArgument(1, Device.class).setLastSeen(answer.getArgument(2, Long.class)); - return mockAccountsManager.update(answer.getArgument(0, Account.class), account -> {}); + when(mockAccountsManager.updateDeviceLastSeen(any(), any(), anyLong())).thenAnswer(invocation -> { + final UUID accountIdentifier = invocation.getArgument(0, UUID.class); + final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow(); + + final Device device = account.getDevice(invocation.getArgument(1, Device.class).getId()).orElseThrow(); + device.setLastSeen(invocation.getArgument(2, Long.class)); + + return mockAccountsManager.update(accountIdentifier, _ -> {}); }); } - public static void setupMockGet(final AccountsManager mockAccountsManager, final Set mockAccounts) { - when(mockAccountsManager.getByAccountIdentifier(any(UUID.class))).thenAnswer(answer -> { + public static void setupMockGet(final AccountsManager mockAccountsManager, final Account account) { + if (account.getUuid() != null || account.getIdentifier(IdentityType.ACI) != null) { + final UUID accountIdentifier = + Objects.requireNonNullElseGet(account.getIdentifier(IdentityType.ACI), account::getUuid); - final UUID uuid = answer.getArgument(0, UUID.class); + when(mockAccountsManager.getByAccountIdentifier(accountIdentifier)) + .thenReturn(Optional.of(account)); - return mockAccounts.stream() - .filter(account -> uuid.equals(account.getUuid())) - .findFirst() - .map(account -> { - try { - return copyAndMarkStale(account); - } catch (final Exception e) { - throw new RuntimeException(e); - } - }); - }); + when(mockAccountsManager.getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(mockAccountsManager.getByServiceIdentifier(new AciServiceIdentifier(accountIdentifier))) + .thenReturn(Optional.of(account)); + + when(mockAccountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier))) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + } + + if (account.getPhoneNumberIdentifier() != null || account.getIdentifier(IdentityType.PNI) != null) { + final UUID phoneNumberIdentifier = + Objects.requireNonNullElseGet(account.getIdentifier(IdentityType.PNI), account::getPhoneNumberIdentifier); + + when(mockAccountsManager.getByPhoneNumberIdentifier(phoneNumberIdentifier)) + .thenReturn(Optional.of(account)); + + when(mockAccountsManager.getByPhoneNumberIdentifierAsync(phoneNumberIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(mockAccountsManager.getByServiceIdentifier(new PniServiceIdentifier(phoneNumberIdentifier))) + .thenReturn(Optional.of(account)); + + when(mockAccountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(phoneNumberIdentifier))) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + } + + if (StringUtils.isNotBlank(account.getNumber())) { + when(mockAccountsManager.getByE164(account.getNumber())).thenReturn(Optional.of(account)); + } + + account.getUsernameHash().ifPresent(usernameHash -> when(mockAccountsManager.getByUsernameHash(usernameHash)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account)))); + + if (account.getUsernameLinkHandle() != null) { + when(mockAccountsManager.getByUsernameLinkHandle(account.getUsernameLinkHandle())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + } } private static Account copyAndMarkStale(Account account) throws IOException { @@ -192,10 +256,6 @@ public class AccountsHelper { return updatedAccount; } - public static Account eqUuid(Account value) { - return argThat(other -> other.getUuid().equals(value.getUuid())); - } - public static Account createAccount(final AccountsManager accountsManager, final String e164) throws InterruptedException { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java index 608ff6ee1..c007aee61 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java @@ -166,7 +166,7 @@ public class AuthHelper { when(VALID_ACCOUNT_TWO.getUuid()).thenReturn(VALID_UUID_TWO); when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO); when(VALID_ACCOUNT_TWO.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_TWO); - when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO); + when(VALID_ACCOUNT_TWO.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI_TWO); when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER); when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID); when(UNDISCOVERABLE_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(UNDISCOVERABLE_PNI); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredAccountsCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredAccountsCommandTest.java index bc2943412..d7b893c07 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredAccountsCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredAccountsCommandTest.java @@ -16,12 +16,14 @@ import static org.mockito.Mockito.when; import java.time.Clock; import java.time.Instant; import java.time.ZoneId; +import java.util.UUID; import java.util.stream.Stream; import net.sourceforge.argparse4j.inf.Namespace; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import reactor.core.publisher.Flux; @@ -64,10 +66,15 @@ class RemoveExpiredAccountsCommandTest { final RemoveExpiredAccountsCommand removeExpiredAccountsCommand = new TestRemoveExpiredAccountsCommand(clock, accountsManager, isDryRun); + final UUID activeAccountIdentifier = UUID.randomUUID(); + final UUID expiredAccountIdentifier = UUID.randomUUID(); + final Account activeAccount = mock(Account.class); + when(activeAccount.getIdentifier(IdentityType.ACI)).thenReturn(activeAccountIdentifier); when(activeAccount.getLastSeen()).thenReturn(clock.instant().toEpochMilli()); final Account expiredAccount = mock(Account.class); + when(expiredAccount.getIdentifier(IdentityType.ACI)).thenReturn(expiredAccountIdentifier); when(expiredAccount.getLastSeen()) .thenReturn(clock.instant().minus(RemoveExpiredAccountsCommand.MAX_IDLE_DURATION).minusMillis(1).toEpochMilli()); @@ -76,8 +83,8 @@ class RemoveExpiredAccountsCommandTest { if (isDryRun) { verify(accountsManager, never()).delete(any(), any()); } else { - verify(accountsManager).delete(expiredAccount, AccountsManager.DeletionReason.EXPIRED); - verify(accountsManager, never()).delete(eq(activeAccount), any()); + verify(accountsManager).delete(expiredAccountIdentifier, AccountsManager.DeletionReason.EXPIRED); + verify(accountsManager, never()).delete(eq(activeAccountIdentifier), any()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredUsernameHoldsCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredUsernameHoldsCommandTest.java index 644a40b73..8e898eca7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredUsernameHoldsCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredUsernameHoldsCommandTest.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; import java.util.function.Consumer; import java.util.stream.IntStream; @@ -31,6 +32,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.util.TestClock; @@ -76,7 +78,10 @@ class RemoveExpiredUsernameHoldsCommandTest { final RemoveExpiredUsernameHoldsCommand removeExpiredUsernameHoldsCommand = new TestRemoveExpiredUsernameHoldsCommand(clock, accountsManager, isDryRun); + final UUID hasHoldsAccountIdentifier = UUID.randomUUID(); + final Account hasHolds = mock(Account.class); + when(hasHolds.getIdentifier(IdentityType.ACI)).thenReturn(hasHoldsAccountIdentifier); final List originalHolds = List.of( // expired new Account.UsernameHold(TestRandomUtil.nextBytes(32), Instant.EPOCH.getEpochSecond()), @@ -92,7 +97,7 @@ class RemoveExpiredUsernameHoldsCommandTest { verifyNoInteractions(accountsManager); } else { ArgumentCaptor> updaterCaptor = ArgumentCaptor.forClass(Consumer.class); - verify(accountsManager, times(1)).update(eq(hasHolds), updaterCaptor.capture()); + verify(accountsManager, times(1)).update(eq(hasHoldsAccountIdentifier), updaterCaptor.capture()); final Consumer consumer = updaterCaptor.getValue(); consumer.accept(hasHolds); verify(hasHolds, times(1)).setUsernameHolds(argThat(holds -> diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/UnlinkDevicesWithIdlePrimaryCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/UnlinkDevicesWithIdlePrimaryCommandTest.java index 33ca98a16..9f04bd190 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/UnlinkDevicesWithIdlePrimaryCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/UnlinkDevicesWithIdlePrimaryCommandTest.java @@ -124,7 +124,7 @@ class UnlinkDevicesWithIdlePrimaryCommandTest { accountWithIdlePrimaryAndLinkedDevice)); if (!isDryRun) { - verify(accountsManager).removeDevice(accountWithIdlePrimaryAndLinkedDevice, linkedDeviceId); + verify(accountsManager).removeDevice(accountWithIdlePrimaryAndLinkedDevice.getIdentifier(IdentityType.ACI), linkedDeviceId); } verifyNoMoreInteractions(accountsManager);