Always use fresh/non-cached copies of accounts when making modifications

This commit is contained in:
Jon Chambers 2026-04-13 11:00:29 -04:00 committed by Jon Chambers
parent 1e2d27585a
commit 9e6cbe8f82
46 changed files with 520 additions and 516 deletions

View File

@ -142,7 +142,7 @@ public class AccountAuthenticator implements Authenticator<BasicCredentials, Aut
Metrics.summary(DAYS_SINCE_LAST_SEEN_DISTRIBUTION_NAME, IS_PRIMARY_DEVICE_TAG, String.valueOf(device.isPrimary()))
.record(Duration.ofMillis(todayInMillisWithOffset - device.getLastSeen()).toDays());
return accountsManager.updateDeviceLastSeen(account, device, Util.todayInMillis(clock));
return accountsManager.updateDeviceLastSeen(account.getIdentifier(IdentityType.ACI), device, Util.todayInMillis(clock));
}
return account;

View File

@ -284,7 +284,7 @@ public class BackupAuthManager {
accountsManager.update(account, a -> {
// Receipt credential expirations must be day aligned. Make sure any manually set backupVoucher is also day
// aligned
final Account.BackupVoucher newPayment = new Account.BackupVoucher(
final Account.BackupVoucher newPayment = new Account.BackupVoucher(
backupVoucher.receiptLevel(),
backupVoucher.expiration().truncatedTo(ChronoUnit.DAYS));
final Account.BackupVoucher existingPayment = a.getBackupVoucher();

View File

@ -108,33 +108,26 @@ public class AccountController {
public void setGcmRegistrationId(@Auth AuthenticatedDevice auth,
@NotNull @Valid GcmRegistrationId registrationId) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final String currentGcmId = accounts.getByAccountIdentifier(auth.accountIdentifier())
.flatMap(account -> account.getDevice(auth.deviceId()))
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED))
.getGcmId();
final Device device = account.getDevice(auth.deviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
if (Objects.equals(device.getGcmId(), registrationId.gcmRegistrationId())) {
if (Objects.equals(currentGcmId, registrationId.gcmRegistrationId())) {
return;
}
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
d.setGcmId(registrationId.gcmRegistrationId());
d.setFetchesMessages(false);
accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), device -> {
device.setApnId(null);
device.setGcmId(registrationId.gcmRegistrationId());
device.setFetchesMessages(false);
});
}
@DELETE
@Path("/gcm/")
public void deleteGcmRegistrationId(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.deviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.updateDevice(account, device.getId(), d -> {
accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), d -> {
d.setGcmId(null);
d.setFetchesMessages(false);
d.setUserAgent("OWA");
@ -148,15 +141,9 @@ public class AccountController {
public void setApnRegistrationId(@Auth AuthenticatedDevice auth,
@NotNull @Valid ApnRegistrationId registrationId) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.deviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
// Unlike FCM tokens, we need current "last updated" timestamps for APNs tokens and so update device records
// unconditionally
accounts.updateDevice(account, device.getId(), d -> {
accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), d -> {
d.setApnId(registrationId.apnRegistrationId());
d.setGcmId(null);
d.setFetchesMessages(false);
@ -166,20 +153,10 @@ public class AccountController {
@DELETE
@Path("/apn/")
public void deleteApnRegistrationId(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.deviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.updateDevice(account, device.getId(), d -> {
accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(), d -> {
d.setApnId(null);
d.setFetchesMessages(false);
if (d.getId() == 1) {
d.setUserAgent("OWI");
} else {
d.setUserAgent("OWP");
}
d.setUserAgent(d.isPrimary() ? "OWI" : "OWP");
});
}
@ -189,20 +166,14 @@ public class AccountController {
public void setRegistrationLock(@Auth AuthenticatedDevice auth, @NotNull @Valid RegistrationLock accountLock) {
final SaltedTokenHash credentials = SaltedTokenHash.generateFor(accountLock.registrationLock());
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.update(account,
accounts.update(auth.accountIdentifier(),
a -> a.setRegistrationLock(credentials.hash(), credentials.salt()));
}
@DELETE
@Path("/registration_lock")
public void removeRegistrationLock(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.update(account, a -> a.setRegistrationLock(null, null));
accounts.update(auth.accountIdentifier(), a -> a.setRegistrationLock(null, null));
}
@PUT
@ -240,7 +211,7 @@ public class AccountController {
throw new ForbiddenException();
}
accounts.updateDevice(account, targetDeviceId, d -> d.setName(deviceName.deviceName()));
accounts.updateDevice(account.getIdentifier(IdentityType.ACI), targetDeviceId, d -> d.setName(deviceName.deviceName()));
}
@PUT
@ -252,10 +223,8 @@ public class AccountController {
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String signalAgent,
@NotNull @Valid AccountAttributes attributes) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Account updatedAccount = accounts.update(account, a -> {
final Account updatedAccount = accounts.update(auth.accountIdentifier(), a -> {
a.getDevice(auth.deviceId()).ifPresent(d -> {
d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName());
@ -306,10 +275,7 @@ public class AccountController {
@ApiResponse(responseCode = "204", description = "Username successfully deleted.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public void deleteUsernameHash(@Auth final AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.clearUsernameHash(account);
accounts.clearUsernameHash(auth.accountIdentifier());
}
@PUT
@ -332,9 +298,6 @@ public class AccountController {
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final ReserveUsernameHashRequest usernameRequest) throws RateLimitExceededException {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
rateLimiters.getUsernameReserveLimiter().validate(auth.accountIdentifier());
for (final byte[] hash : usernameRequest.usernameHashes()) {
@ -345,7 +308,7 @@ public class AccountController {
try {
final AccountsManager.UsernameReservation reservation =
accounts.reserveUsernameHash(account, usernameRequest.usernameHashes());
accounts.reserveUsernameHash(auth.accountIdentifier(), usernameRequest.usernameHashes());
return new ReserveUsernameHashResponse(reservation.reservedUsernameHash());
} catch (final UsernameHashNotAvailableException e) {
@ -374,20 +337,17 @@ public class AccountController {
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final ConfirmUsernameHashRequest confirmRequest) throws RateLimitExceededException {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
try {
usernameHashZkProofVerifier.verifyProof(confirmRequest.zkProof(), confirmRequest.usernameHash());
} catch (final BaseUsernameException e) {
throw new WebApplicationException(Response.status(422).build());
}
rateLimiters.getUsernameSetLimiter().validate(account.getUuid());
rateLimiters.getUsernameSetLimiter().validate(auth.accountIdentifier());
try {
final Account updatedAccount = accounts.confirmReservedUsernameHash(
account,
auth.accountIdentifier(),
confirmRequest.usernameHash(),
confirmRequest.encryptedUsername());
@ -474,7 +434,7 @@ public class AccountController {
} else {
usernameLinkHandle = UUID.randomUUID();
}
updateUsernameLink(account, usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue());
updateUsernameLink(account.getIdentifier(IdentityType.ACI), usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue());
return new UsernameLinkHandle(usernameLinkHandle);
}
@ -494,10 +454,7 @@ public class AccountController {
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.accountIdentifier());
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
clearUsernameLink(account);
clearUsernameLink(auth.accountIdentifier());
}
@GET
@ -559,24 +516,21 @@ public class AccountController {
@DELETE
@Path("/me")
public void deleteAccount(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST);
accounts.delete(auth.accountIdentifier(), AccountsManager.DeletionReason.USER_REQUEST);
}
private void clearUsernameLink(final Account account) {
updateUsernameLink(account, null, null);
private void clearUsernameLink(final UUID accountIdentifier) {
updateUsernameLink(accountIdentifier, null, null);
}
private void updateUsernameLink(
final Account account,
final UUID accountIdentifier,
@Nullable final UUID usernameLinkHandle,
@Nullable final byte[] encryptedUsername) {
if ((encryptedUsername == null) ^ (usernameLinkHandle == null)) {
throw new IllegalStateException("Both or neither arguments must be null");
}
accounts.update(account, a -> a.setUsernameLinkDetails(usernameLinkHandle, encryptedUsername));
accounts.update(accountIdentifier, a -> a.setUsernameLinkDetails(usernameLinkHandle, encryptedUsername));
}
private void requireNotAuthenticated(final Optional<AuthenticatedDevice> authenticatedAccount) {

View File

@ -179,10 +179,7 @@ public class AccountControllerV2 {
@Auth AuthenticatedDevice auth,
@NotNull @Valid PhoneNumberDiscoverabilityRequest phoneNumberDiscoverability) {
final Account account = accountsManager.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accountsManager.update(account, a -> a.setDiscoverableByPhoneNumber(
accountsManager.update(auth.accountIdentifier(), a -> a.setDiscoverableByPhoneNumber(
phoneNumberDiscoverability.discoverableByPhoneNumber()));
}

View File

@ -159,10 +159,7 @@ public class DeviceController {
throw new ForbiddenException();
}
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accounts.removeDevice(account, deviceId);
accounts.removeDevice(auth.accountIdentifier(), deviceId);
}
/**
@ -282,20 +279,21 @@ public class DeviceController {
}
try {
final Pair<Account, Device> accountAndDevice = accounts.addDevice(account, new DeviceSpec(deviceAttributes.name(),
authorizationHeader.getPassword(),
signalAgent,
capabilities,
deviceAttributes.registrationId(),
deviceAttributes.phoneNumberIdentityRegistrationId(),
deviceAttributes.fetchesMessages(),
deviceActivationRequest.apnToken(),
deviceActivationRequest.gcmToken(),
deviceActivationRequest.aciSignedPreKey(),
deviceActivationRequest.pniSignedPreKey(),
deviceActivationRequest.aciPqLastResortPreKey(),
deviceActivationRequest.pniPqLastResortPreKey()),
linkDeviceRequest.verificationCode());
final Pair<Account, Device> accountAndDevice = accounts.addDevice(account.getIdentifier(IdentityType.ACI),
new DeviceSpec(deviceAttributes.name(),
authorizationHeader.getPassword(),
signalAgent,
capabilities,
deviceAttributes.registrationId(),
deviceAttributes.phoneNumberIdentityRegistrationId(),
deviceAttributes.fetchesMessages(),
deviceActivationRequest.apnToken(),
deviceActivationRequest.gcmToken(),
deviceActivationRequest.aciSignedPreKey(),
deviceActivationRequest.pniSignedPreKey(),
deviceActivationRequest.aciPqLastResortPreKey(),
deviceActivationRequest.pniPqLastResortPreKey()),
linkDeviceRequest.verificationCode());
return new LinkDeviceResponse(
accountAndDevice.first().getIdentifier(IdentityType.ACI),
@ -384,15 +382,8 @@ public class DeviceController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/capabilities")
public void setCapabilities(@Auth final AuthenticatedDevice auth,
@NotNull
final Map<String, Boolean> capabilities) {
final Account account = accounts.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accounts.updateDevice(account, auth.deviceId(),
public void setCapabilities(@Auth final AuthenticatedDevice auth, @NotNull final Map<String, Boolean> capabilities) {
accounts.updateDevice(auth.accountIdentifier(), auth.deviceId(),
d -> d.setCapabilities(DeviceCapabilityAdapter.mapToSet(capabilities)));
}

View File

@ -15,16 +15,12 @@ import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
import java.time.Clock;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import javax.annotation.Nonnull;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.zkgroup.InvalidInputException;
@ -35,7 +31,6 @@ import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.entities.RedeemReceiptRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
@ -121,9 +116,6 @@ public class DonationController {
.build();
}
final Account account = accountsManager.getByAccountIdentifier(auth.accountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final boolean receiptMatched = redeemedReceiptsManager.put(
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.accountIdentifier()).join();
if (!receiptMatched) {
@ -133,7 +125,7 @@ public class DonationController {
.build();
}
accountsManager.update(account, a -> {
accountsManager.update(auth.accountIdentifier(), a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId);

View File

@ -215,7 +215,7 @@ public class ProfileController {
currentAvatar.ifPresent(profilesManager::deleteAvatar);
}
accountsManager.update(account, a -> {
accountsManager.update(account.getIdentifier(IdentityType.ACI), a -> {
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))

View File

@ -97,8 +97,8 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
@Override
public DeleteAccountResponse deleteAccount(final DeleteAccountRequest request) {
accountsManager.delete(getAuthenticatedAccount(AuthenticationUtil.requireAuthenticatedPrimaryDevice()),
AccountsManager.DeletionReason.USER_REQUEST);
accountsManager.delete(AuthenticationUtil.requireAuthenticatedPrimaryDevice().accountIdentifier(),
AccountsManager.DeletionReason.USER_REQUEST);
return DeleteAccountResponse.getDefaultInstance();
}
@ -110,7 +110,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
final SaltedTokenHash credentials =
SaltedTokenHash.generateFor(HexFormat.of().withLowerCase().formatHex(request.getRegistrationLock().toByteArray()));
accountsManager.update(getAuthenticatedAccount(AuthenticationUtil.requireAuthenticatedPrimaryDevice()),
accountsManager.update(AuthenticationUtil.requireAuthenticatedPrimaryDevice().accountIdentifier(),
account -> account.setRegistrationLock(credentials.hash(), credentials.salt()));
return SetRegistrationLockResponse.getDefaultInstance();
@ -118,7 +118,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
@Override
public ClearRegistrationLockResponse clearRegistrationLock(final ClearRegistrationLockRequest request) {
accountsManager.update(getAuthenticatedAccount(AuthenticationUtil.requireAuthenticatedPrimaryDevice()),
accountsManager.update(AuthenticationUtil.requireAuthenticatedPrimaryDevice().accountIdentifier(),
account -> account.setRegistrationLock(null, null));
return ClearRegistrationLockResponse.getDefaultInstance();
@ -142,11 +142,9 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
rateLimiters.getUsernameReserveLimiter().validate(authenticatedDevice.accountIdentifier());
final Account account = getAuthenticatedAccount();
try {
final AccountsManager.UsernameReservation usernameReservation =
accountsManager.reserveUsernameHash(account, usernameHashes);
accountsManager.reserveUsernameHash(authenticatedDevice.accountIdentifier(), usernameHashes);
return ReserveUsernameHashResponse.newBuilder()
.setUsernameHash(ByteString.copyFrom(usernameReservation.reservedUsernameHash()))
@ -172,7 +170,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
rateLimiters.getUsernameSetLimiter().validate(authenticatedDevice.accountIdentifier());
try {
final Account updatedAccount = accountsManager.confirmReservedUsernameHash(getAuthenticatedAccount(),
final Account updatedAccount = accountsManager.confirmReservedUsernameHash(authenticatedDevice.accountIdentifier(),
request.getUsernameHash().toByteArray(),
request.getUsernameCiphertext().toByteArray());
@ -196,7 +194,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
@Override
public DeleteUsernameHashResponse deleteUsernameHash(final DeleteUsernameHashRequest request) {
accountsManager.clearUsernameHash(getAuthenticatedAccount());
accountsManager.clearUsernameHash(AuthenticationUtil.requireAuthenticatedDevice().accountIdentifier());
return DeleteUsernameHashResponse.getDefaultInstance();
}
@ -220,7 +218,8 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
? account.getUsernameLinkHandle()
: UUID.randomUUID();
accountsManager.update(account, a -> a.setUsernameLinkDetails(linkHandle, request.getUsernameCiphertext().toByteArray()));
accountsManager.update(account.getIdentifier(IdentityType.ACI),
a -> a.setUsernameLinkDetails(linkHandle, request.getUsernameCiphertext().toByteArray()));
return responseBuilder.setUsernameLinkHandle(UUIDUtil.toByteString(linkHandle)).build();
}
@ -232,7 +231,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
rateLimiters.getUsernameLinkOperationLimiter().validate(authenticatedDevice.accountIdentifier());
accountsManager.update(getAuthenticatedAccount(), a -> a.setUsernameLinkDetails(null, null));
accountsManager.update(authenticatedDevice.accountIdentifier(), a -> a.setUsernameLinkDetails(null, null));
return DeleteUsernameLinkResponse.getDefaultInstance();
}
@ -245,7 +244,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH, request.getUnidentifiedAccessKey().size()));
}
accountsManager.update(getAuthenticatedAccount(), account -> {
accountsManager.update(AuthenticationUtil.requireAuthenticatedDevice().accountIdentifier(), account -> {
account.setUnrestrictedUnidentifiedAccess(request.getAllowUnrestrictedUnidentifiedAccess());
account.setUnidentifiedAccessKey(request.getAllowUnrestrictedUnidentifiedAccess() ? null : request.getUnidentifiedAccessKey().toByteArray());
});
@ -255,7 +254,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
@Override
public SetDiscoverableByPhoneNumberResponse setDiscoverableByPhoneNumber(final SetDiscoverableByPhoneNumberRequest request) {
accountsManager.update(getAuthenticatedAccount(),
accountsManager.update(AuthenticationUtil.requireAuthenticatedDevice().accountIdentifier(),
account -> account.setDiscoverableByPhoneNumber(request.getDiscoverableByPhoneNumber()));
return SetDiscoverableByPhoneNumberResponse.getDefaultInstance();
@ -283,7 +282,7 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
rateLimiters.getSetZkCredentialKeyLimiter().validate(authenticatedDevice.accountIdentifier());
accountsManager.update(authenticatedAccount, account -> account.setZkCredentialKey(zkCredentialKey));
accountsManager.update(authenticatedDevice.accountIdentifier(), account -> account.setZkCredentialKey(zkCredentialKey));
return SetZkCredentialKeyResponse.getDefaultInstance();
}

View File

@ -80,9 +80,7 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase {
throw GrpcExceptions.badAuthentication("linked devices cannot remove devices other than themselves");
}
final byte deviceId = DeviceIdUtil.validate(request.getId());
accountsManager.removeDevice(getAuthenticatedAccount(), deviceId);
accountsManager.removeDevice(authenticatedDevice.accountIdentifier(), DeviceIdUtil.validate(request.getId()));
return RemoveDeviceResponse.getDefaultInstance();
}
@ -106,7 +104,7 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase {
return SetDeviceNameResponse.newBuilder().setTargetDeviceNotFound(NotFound.getDefaultInstance()).build();
}
accountsManager.updateDevice(account, deviceId, device -> device.setName(request.getName().toByteArray()));
accountsManager.updateDevice(account.getIdentifier(IdentityType.ACI), deviceId, device -> device.setName(request.getName().toByteArray()));
return SetDeviceNameResponse.newBuilder().setSuccess(Empty.getDefaultInstance()).build();
}
@ -141,7 +139,7 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase {
.orElseThrow(() -> GrpcExceptions.invalidCredentials("invalid credentials"));
if (!Objects.equals(device.getApnId(), apnsToken) || !Objects.equals(device.getGcmId(), fcmToken)) {
accountsManager.updateDevice(account, authenticatedDevice.deviceId(), d -> {
accountsManager.updateDevice(account.getIdentifier(IdentityType.ACI), authenticatedDevice.deviceId(), d -> {
d.setApnId(apnsToken);
d.setGcmId(fcmToken);
d.setFetchesMessages(false);
@ -154,9 +152,8 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase {
@Override
public ClearPushTokenResponse clearPushToken(final ClearPushTokenRequest request) {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
final Account account = getAuthenticatedAccount();
accountsManager.updateDevice(account, authenticatedDevice.deviceId(), device -> {
accountsManager.updateDevice(authenticatedDevice.accountIdentifier(), authenticatedDevice.deviceId(), device -> {
if (StringUtils.isNotBlank(device.getApnId())) {
device.setUserAgent(device.isPrimary() ? "OWI" : "OWP");
} else if (StringUtils.isNotBlank(device.getGcmId())) {
@ -179,7 +176,7 @@ public class DevicesGrpcService extends SimpleDevicesGrpc.DevicesImplBase {
.map(DeviceCapabilityUtil::fromGrpcDeviceCapability)
.collect(Collectors.toSet());
accountsManager.updateDevice(getAuthenticatedAccount(), authenticatedDevice.deviceId(),
accountsManager.updateDevice(authenticatedDevice.accountIdentifier(), authenticatedDevice.deviceId(),
device -> device.setCapabilities(capabilities));
return SetCapabilitiesResponse.getDefaultInstance();

View File

@ -124,7 +124,7 @@ public class ProfileGrpcService extends SimpleProfileGrpc.ProfileImplBase {
}
};
profilesManager.set(account.getUuid(),
profilesManager.set(account.getIdentifier(IdentityType.ACI),
new VersionedProfile(
request.getVersion(),
request.getName().toByteArray(),
@ -135,7 +135,7 @@ public class ProfileGrpcService extends SimpleProfileGrpc.ProfileImplBase {
request.getPhoneNumberSharing().toByteArray(),
request.getCommitment().toByteArray()));
accountsManager.update(account, a -> {
accountsManager.update(account.getIdentifier(IdentityType.ACI), a -> {
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges,

View File

@ -17,6 +17,7 @@ import java.util.function.BiConsumer;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -195,7 +196,7 @@ public class PushNotificationManager {
// updating an uninstalled feedback timestamp though.
accountsManager.getByAccountIdentifier(account.getUuid()).ifPresent(rereadAccount ->
rereadAccount.getDevice(device.getId()).ifPresent(rereadDevice ->
accountsManager.updateDevice(rereadAccount, device.getId(), d -> {
accountsManager.updateDevice(rereadAccount.getIdentifier(IdentityType.ACI), device.getId(), d -> {
// Don't clear the token if it's already changed
if (originalToken.equals(getPushToken(d, tokenType))) {
switch (tokenType) {

View File

@ -433,11 +433,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
return account;
}
public Pair<Account, Device> addDevice(final Account account, final DeviceSpec deviceSpec, final String linkDeviceToken)
public Pair<Account, Device> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final String linkDeviceToken)
throws LinkDeviceTokenAlreadyUsedException {
return accountLockManager.withLock(Set.of(account.getPhoneNumberIdentifier()),
() -> addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, linkDeviceToken, MAX_UPDATE_ATTEMPTS),
final UUID phoneNumberIdentifier = accounts.getByAccountIdentifier(accountIdentifier)
.map(account -> account.getIdentifier(IdentityType.PNI))
.orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier));
return accountLockManager.withLock(Set.of(phoneNumberIdentifier),
() -> addDevice(accountIdentifier, deviceSpec, linkDeviceToken, MAX_UPDATE_ATTEMPTS),
accountLockExecutor);
}
@ -622,13 +626,18 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
*
* @return the updated Account
*/
public Account removeDevice(final Account account, final byte deviceId) {
public Account removeDevice(final UUID accountIdentifier, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device");
}
return accountLockManager.withLock(Set.of(account.getPhoneNumberIdentifier()),
() -> removeDevice(account.getIdentifier(IdentityType.ACI), deviceId, MAX_UPDATE_ATTEMPTS),
// Always fetch a fresh, non-cached copy of the account before making modifications
final UUID phoneNumberIdentifier = accounts.getByAccountIdentifier(accountIdentifier)
.map(account -> account.getIdentifier(IdentityType.PNI))
.orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier));
return accountLockManager.withLock(Set.of(phoneNumberIdentifier),
() -> removeDevice(accountIdentifier, deviceId, MAX_UPDATE_ATTEMPTS),
accountLockExecutor);
}
@ -672,13 +681,17 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
}
public Account changeNumber(final Account account,
public Account changeNumber(final UUID accountIdentifier,
final String targetNumber,
final IdentityKey pniIdentityKey,
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
// Always fetch a fresh, non-cached copy of the account before making modifications
final Account account = accounts.getByAccountIdentifier(accountIdentifier)
.orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier));
final UUID targetPhoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber).join();
try {
@ -743,7 +756,6 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
buildPniKeyWriteItems(targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys);
return updateWithRetries(
account,
a -> {
setPniKeys(a, pniIdentityKey, pniRegistrationIds);
return true;
@ -815,18 +827,23 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
/// Reserve a username hash so that no other accounts may take it.
///
/// The reserved hash can later be set with [#confirmReservedUsernameHash(Account, byte\[\], byte\[\])]. The
/// The reserved hash can later be set with [#confirmReservedUsernameHash(UUID, byte\[\], byte\[\])]. The
/// reservation will eventually expire, after which point confirmReservedUsernameHash may fail if another account has
/// taken the username hash.
///
/// @param account the account to update
/// @param accountIdentifier the unique identifier of the account to update
/// @param requestedUsernameHashes the list of username hashes to attempt to reserve
///
/// @return the reserved username hash
///
/// @throws UsernameHashNotAvailableException if none of the given username hashes are available
public UsernameReservation reserveUsernameHash(final Account account, final List<byte[]> requestedUsernameHashes)
public UsernameReservation reserveUsernameHash(final UUID accountIdentifier, final List<byte[]> requestedUsernameHashes)
throws UsernameHashNotAvailableException {
// Always fetch a fresh, non-cached copy of the account before making modifications
final Account account = accounts.getByAccountIdentifier(accountIdentifier)
.orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier));
if (account.getUsernameHash().filter(
oldHash -> requestedUsernameHashes.stream().anyMatch(hash -> Arrays.equals(oldHash, hash)))
.isPresent()) {
@ -844,7 +861,6 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
redisDelete(account);
final Account updatedAccount = updateWithRetries(
account,
_ -> true,
a -> reservedUsernameHash.set(
checkAndReserveNextUsernameHash(a, new ArrayDeque<>(requestedUsernameHashes))),
@ -873,9 +889,9 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
}
/// Set a username hash previously reserved with {@link #reserveUsernameHash(Account, List)}
/// Set a username hash previously reserved with {@link #reserveUsernameHash(UUID, List)}
///
/// @param account the account to update
/// @param accountIdentifier identifier of the account to update
/// @param reservedUsernameHash the previously reserved username hash
/// @param encryptedUsername the encrypted form of the previously reserved username for the username link
///
@ -884,9 +900,13 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
/// @throws UsernameHashNotAvailableException if the reserved username hash has been taken (because the reservation
/// expired)
/// @throws UsernameReservationNotFoundException if `reservedUsernameHash` was not reserved for the account
public Account confirmReservedUsernameHash(final Account account, final byte[] reservedUsernameHash, @Nullable final byte[] encryptedUsername)
public Account confirmReservedUsernameHash(final UUID accountIdentifier, final byte[] reservedUsernameHash, @Nullable final byte[] encryptedUsername)
throws UsernameReservationNotFoundException, UsernameHashNotAvailableException {
// Always fetch a fresh, non-cached copy of the account before making modifications
final Account account = accounts.getByAccountIdentifier(accountIdentifier)
.orElseThrow(() -> new IllegalArgumentException("Account not found: " + accountIdentifier));
if (account.getUsernameHash().map(currentUsernameHash -> Arrays.equals(currentUsernameHash, reservedUsernameHash)).orElse(false)) {
// the client likely already succeeded and is retrying
return account;
@ -900,7 +920,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
redisDelete(account);
final Account updatedAccount = updateWithRetries(account,
final Account updatedAccount = updateWithRetries(
_ -> true,
a -> accounts.confirmUsernameHash(a, reservedUsernameHash, encryptedUsername),
() -> accounts.getByAccountIdentifier(account.getUuid()).orElseThrow(),
@ -911,13 +931,10 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
return updatedAccount;
}
public Account clearUsernameHash(final Account account) {
redisDelete(account);
final Account updatedAccount = updateWithRetries(account,
_ -> true,
public Account clearUsernameHash(final UUID accountIdentifier) {
final Account updatedAccount = updateWithRetries(_ -> true,
accounts::clearUsernameHash,
() -> accounts.getByAccountIdentifier(account.getUuid()).orElseThrow(),
() -> accounts.getByAccountIdentifier(accountIdentifier).orElseThrow(),
AccountChangeValidator.USERNAME_CHANGE_VALIDATOR);
redisDelete(updatedAccount);
@ -925,8 +942,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
return updatedAccount;
}
public Account update(Account account, Consumer<Account> updater) {
return update(account, a -> {
public Account update(final Account account, final Consumer<Account> updater) {
final Account updatedAccount = update(account.getIdentifier(IdentityType.ACI), updater);
account.markStale();
return updatedAccount;
}
public Account update(final UUID accountIdentifier, final Consumer<Account> updater) {
return update(accountIdentifier, a -> {
updater.accept(a);
// assume that all updaters passed to the public method actually modify the account
return true;
@ -934,11 +958,11 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
/**
* Specialized version of {@link #updateDevice(Account, byte, Consumer)} that minimizes potentially contentious and
* Specialized version of {@link #updateDevice(UUID, byte, Consumer)} that minimizes potentially contentious and
* redundant updates of {@code device.lastSeen}
*/
public Account updateDeviceLastSeen(Account account, Device device, final long lastSeen) {
return update(account, a -> {
public Account updateDeviceLastSeen(UUID accountIdentifier, Device device, final long lastSeen) {
return update(accountIdentifier, a -> {
final Optional<Device> maybeDevice = a.getDevice(device.getId());
@ -956,21 +980,16 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
/**
* @param account account to update
* @param accountIdentifier identifier of account to update
* @param updater must return {@code true} if the account was actually updated
*/
private Account update(Account account, Function<Account, Boolean> updater) {
private Account update(UUID accountIdentifier, Function<Account, Boolean> updater) {
return updateTimer.record(() -> {
redisDelete(account);
final UUID uuid = account.getUuid();
final Account updatedAccount = updateWithRetries(account,
updater,
final Account updatedAccount = updateWithRetries(updater,
accounts::update,
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
() -> accounts.getByAccountIdentifier(accountIdentifier).orElseThrow(),
AccountChangeValidator.GENERAL_CHANGE_VALIDATOR);
redisSet(updatedAccount);
@ -979,49 +998,38 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
});
}
private <E extends Exception> Account updateWithRetries(Account account,
final Function<Account, Boolean> updater,
private <E extends Exception> Account updateWithRetries(final Function<Account, Boolean> updater,
final ThrowingConsumer<Account, E> persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) throws E {
Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
if (!updater.apply(account)) {
return account;
}
final int maxTries = 10;
int tries = 0;
while (tries < maxTries) {
try {
persister.accept(account);
final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
account.markStale();
changeValidator.validateChange(originalAccount, updatedAccount);
return updatedAccount;
} catch (final ContestedOptimisticLockException e) {
tries++;
account = retriever.get();
originalAccount = AccountUtil.cloneAccountAsNotStale(account);
final Account account = retriever.get();
final Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
if (!updater.apply(account)) {
return account;
}
persister.accept(account);
changeValidator.validateChange(originalAccount, account);
return account;
} catch (final ContestedOptimisticLockException e) {
tries++;
}
}
throw new OptimisticLockRetryLimitExceededException();
}
public Account updateDevice(Account account, byte deviceId, Consumer<Device> deviceUpdater) {
return update(account, a -> {
public Account updateDevice(final UUID accountIdentifier, byte deviceId, Consumer<Device> deviceUpdater) {
return update(accountIdentifier, a -> {
a.getDevice(deviceId).ifPresent(deviceUpdater);
// assume that all updaters passed to the public method actually modify the device
return true;
@ -1106,9 +1114,19 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
return accounts.getAll(segments, scheduler);
}
public void delete(final Account account, final DeletionReason deletionReason) {
public void delete(final UUID accountIdentifier, final DeletionReason deletionReason) {
final Timer.Sample sample = Timer.start();
final Optional<Account> maybeAccount = accounts.getByAccountIdentifier(accountIdentifier);
if (maybeAccount.isEmpty()) {
// In most cases, failing to find an account would be an error, but in this case, it's probably a sign that we've
// already succeeded in deleting the account and this is a spurious retry.
return;
}
final Account account = maybeAccount.get();
try {
accountLockManager.withLock(Set.of(account.getPhoneNumberIdentifier()), () -> {
delete(account);
@ -1328,21 +1346,6 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
getAccountEntityKey(account.getUuid())))));
}
private CompletableFuture<Void> redisDeleteAsync(final Account account) {
final Timer.Sample sample = Timer.start();
final String[] keysToDelete = new String[]{
getAccountMapKey(account.getPhoneNumberIdentifier().toString()),
getAccountEntityKey(account.getUuid())
};
return ResilienceUtil.getGeneralRedisRetry(RETRY_NAME).executeCompletionStage(retryExecutor,
() -> cacheCluster.withCluster(connection -> connection.async().del(keysToDelete))
.thenRun(Util.NOOP))
.toCompletableFuture()
.whenComplete((_, _) -> sample.stop(redisDeleteTimer));
}
public CompletableFuture<Optional<DeviceInfo>> waitForNewLinkedDevice(
final UUID accountIdentifier,
final Device linkingDevice,

View File

@ -80,7 +80,7 @@ public class ChangeNumberManager {
}
final Account updatedAccount = accountsManager.changeNumber(
account, number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
account.getIdentifier(IdentityType.ACI), number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
try {
// Now that we've actually updated the account, populate the "updated PNI" field on all envelopes

View File

@ -311,7 +311,7 @@ public class MessagePersister implements Managed {
} else {
logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", accountUuid, deviceId);
return Mono.fromRunnable(() -> accountsManager.removeDevice(account, deviceId))
return Mono.fromRunnable(() -> accountsManager.removeDevice(accountUuid, deviceId))
.subscribeOn(persistQueueScheduler)
.then(Mono.empty());
}

View File

@ -13,6 +13,7 @@ import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.AccountsManager.DeletionReason;
@ -51,7 +52,7 @@ public class DeleteUserCommand extends AbstractCommandWithDependencies {
Optional<Account> account = accountsManager.getByE164(user);
if (account.isPresent()) {
accountsManager.delete(account.get(), DeletionReason.ADMIN_DELETED);
accountsManager.delete(account.get().getIdentifier(IdentityType.ACI), DeletionReason.ADMIN_DELETED);
logger.warn("Removed " + account.get().getNumber());
} else {
logger.warn("Account not found");

View File

@ -16,6 +16,7 @@ import java.time.Instant;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import reactor.core.publisher.Mono;
@ -67,7 +68,7 @@ public class RemoveExpiredAccountsCommand extends AbstractSinglePassCrawlAccount
.flatMap(expiredAccount -> {
final Mono<Void> deleteAccountMono = isDryRun
? Mono.empty()
: Mono.fromRunnable(() -> getCommandDependencies().accountsManager().delete(expiredAccount, AccountsManager.DeletionReason.EXPIRED))
: Mono.fromRunnable(() -> getCommandDependencies().accountsManager().delete(expiredAccount.getIdentifier(IdentityType.ACI), AccountsManager.DeletionReason.EXPIRED))
.subscribeOn(Schedulers.boundedElastic())
.then();

View File

@ -21,6 +21,7 @@ import java.util.stream.Collectors;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
@ -130,7 +131,7 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc
return Flux.fromIterable(expiredDevices)
.flatMap(deviceId ->
Mono.fromRunnable(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId))
Mono.fromRunnable(() -> getCommandDependencies().accountsManager().removeDevice(account.getIdentifier(IdentityType.ACI), deviceId))
.retryWhen(Retry.backoff(maxRetries, Duration.ofSeconds(1))
.doAfterRetry(ignored -> retryCounter.increment())
.onRetryExhaustedThrow((spec, rs) -> rs.failure()))

View File

@ -20,6 +20,7 @@ import java.util.concurrent.atomic.AtomicLong;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import reactor.core.publisher.Flux;
@ -85,7 +86,7 @@ public class RemoveExpiredUsernameHoldsCommand extends AbstractSinglePassCrawlAc
final int holdsToRemove = removeExpired(holds);
final Mono<Void> purgeMono = isDryRun || holdsToRemove == 0
? Mono.empty()
: Mono.fromRunnable(() -> accountManager.update(account, a -> a.setUsernameHolds(holds)))
: Mono.fromRunnable(() -> accountManager.update(account.getIdentifier(IdentityType.ACI), a -> a.setUsernameHolds(holds)))
.subscribeOn(Schedulers.boundedElastic())
.then();
Metrics.counter(INSPECTED_ACCOUNTS_COUNTER_NAME,

View File

@ -12,6 +12,7 @@ import java.util.UUID;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -61,7 +62,7 @@ public class SetUserDiscoverabilityCommand extends AbstractCommandWithDependenci
maybeAccount.ifPresentOrElse(account -> {
final boolean initiallyDiscoverable = account.isDiscoverableByPhoneNumber();
accountsManager.update(account, a -> a.setDiscoverableByPhoneNumber(namespace.getBoolean("discoverable")));
accountsManager.update(account.getIdentifier(IdentityType.ACI), a -> a.setDiscoverableByPhoneNumber(namespace.getBoolean("discoverable")));
System.out.format("Set discoverability flag for %s to %s (was previously %s)\n",
namespace.getString("user"),

View File

@ -13,6 +13,7 @@ import net.sourceforge.argparse4j.impl.Arguments;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ -51,9 +52,6 @@ public class UnlinkDeviceCommand extends AbstractCommandWithDependencies {
final UUID aci = UUID.fromString(namespace.getString("uuid").trim());
final List<Byte> deviceIds = namespace.getList("deviceIds");
Account account = deps.accountsManager().getByAccountIdentifier(aci)
.orElseThrow(() -> new IllegalArgumentException("account id " + aci + " does not exist"));
if (deviceIds.contains(Device.PRIMARY_ID)) {
throw new IllegalArgumentException("cannot delete primary device");
}
@ -61,7 +59,7 @@ public class UnlinkDeviceCommand extends AbstractCommandWithDependencies {
for (byte deviceId : deviceIds) {
/** see {@link org.whispersystems.textsecuregcm.controllers.DeviceController#removeDevice} */
System.out.format("Removing device %s::%d\n", aci, deviceId);
deps.accountsManager().removeDevice(account, deviceId);
deps.accountsManager().removeDevice(aci, deviceId);
}
}
}

View File

@ -93,7 +93,7 @@ public class UnlinkDevicesWithIdlePrimaryCommand extends AbstractSinglePassCrawl
.filter(account -> isPrimaryDeviceIdle(account, currentTime, idleDurationThreshold))
.flatMap(accountWithIdlePrimaryDevice -> Flux.fromIterable(accountWithIdlePrimaryDevice.getDevices())
.filter(device -> !device.isPrimary())
.map(linkedDevice -> Tuples.of(accountWithIdlePrimaryDevice, linkedDevice.getId())))
.map(linkedDevice -> Tuples.of(accountWithIdlePrimaryDevice.getIdentifier(IdentityType.ACI), linkedDevice.getId())))
.flatMap(accountAndLinkedDeviceId -> {
final Mono<Account> unlinkDeviceMono = isDryRun
? Mono.empty()
@ -104,8 +104,7 @@ public class UnlinkDevicesWithIdlePrimaryCommand extends AbstractSinglePassCrawl
.doOnSuccess(ignored -> unlinkDeviceCounter.increment())
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).maxBackoff(Duration.ofSeconds(4)))
.onErrorResume(throwable -> {
logger.warn("Failed to unlink device to delete account {}:{}", accountAndLinkedDeviceId.getT1().getIdentifier(
IdentityType.ACI), accountAndLinkedDeviceId.getT2(), throwable);
logger.warn("Failed to unlink device to delete account {}:{}", accountAndLinkedDeviceId.getT1(), accountAndLinkedDeviceId.getT2(), throwable);
return Mono.empty();
});

View File

@ -64,6 +64,9 @@ class AccountAuthenticatorTest {
oldAccount = AccountsHelper.generateTestAccount("+14088675311", UUID.fromString("adfce52b-9299-4c25-9c51-412fb420c6a6"), UUID.randomUUID(), List.of(generateTestDevice(oldTime)), null);
AccountsHelper.setupMockUpdate(accountsManager);
AccountsHelper.setupMockGet(accountsManager, acct1);
AccountsHelper.setupMockGet(accountsManager, acct2);
AccountsHelper.setupMockGet(accountsManager, oldAccount);
}
private static Device generateTestDevice(final long lastSeen) {
@ -81,17 +84,14 @@ class AccountAuthenticatorTest {
final Device device1 = acct1.getDevices().stream().findFirst().orElseThrow();
final Device device2 = acct2.getDevices().stream().findFirst().orElseThrow();
final Account updatedAcct1 = accountAuthenticator.updateLastSeen(acct1, device1);
final Account updatedAcct2 = accountAuthenticator.updateLastSeen(acct2, device2);
accountAuthenticator.updateLastSeen(acct1, device1);
accountAuthenticator.updateLastSeen(acct2, device2);
verify(accountsManager, never()).updateDeviceLastSeen(eq(acct1), any(), anyLong());
verify(accountsManager).updateDeviceLastSeen(eq(acct2), eq(device2), anyLong());
verify(accountsManager, never()).updateDeviceLastSeen(eq(acct1.getIdentifier(IdentityType.ACI)), any(), anyLong());
verify(accountsManager).updateDeviceLastSeen(eq(acct2.getIdentifier(IdentityType.ACI)), eq(device2), anyLong());
assertThat(device1.getLastSeen()).isEqualTo(yesterday);
assertThat(device2.getLastSeen()).isEqualTo(today);
assertThat(acct1).isSameAs(updatedAcct1);
assertThat(acct2).isNotSameAs(updatedAcct2);
}
@Test
@ -101,17 +101,14 @@ class AccountAuthenticatorTest {
final Device device1 = acct1.getDevices().stream().findFirst().orElseThrow();
final Device device2 = acct2.getDevices().stream().findFirst().orElseThrow();
final Account updatedAcct1 = accountAuthenticator.updateLastSeen(acct1, device1);
final Account updatedAcct2 = accountAuthenticator.updateLastSeen(acct2, device2);
accountAuthenticator.updateLastSeen(acct1, device1);
accountAuthenticator.updateLastSeen(acct2, device2);
verify(accountsManager, never()).updateDeviceLastSeen(eq(acct1), any(), anyLong());
verify(accountsManager, never()).updateDeviceLastSeen(eq(acct2), any(), anyLong());
verify(accountsManager, never()).updateDeviceLastSeen(eq(acct1.getIdentifier(IdentityType.ACI)), any(), anyLong());
verify(accountsManager, never()).updateDeviceLastSeen(eq(acct2.getIdentifier(IdentityType.ACI)), any(), anyLong());
assertThat(device1.getLastSeen()).isEqualTo(yesterday);
assertThat(device2.getLastSeen()).isEqualTo(yesterday);
assertThat(acct1).isSameAs(updatedAcct1);
assertThat(acct2).isSameAs(updatedAcct2);
}
@Test
@ -121,17 +118,14 @@ class AccountAuthenticatorTest {
final Device device1 = acct1.getDevices().stream().findFirst().orElseThrow();
final Device device2 = acct2.getDevices().stream().findFirst().orElseThrow();
final Account updatedAcct1 = accountAuthenticator.updateLastSeen(acct1, device1);
final Account updatedAcct2 = accountAuthenticator.updateLastSeen(acct2, device2);
accountAuthenticator.updateLastSeen(acct1, device1);
accountAuthenticator.updateLastSeen(acct2, device2);
verify(accountsManager).updateDeviceLastSeen(eq(acct1), eq(device1), anyLong());
verify(accountsManager).updateDeviceLastSeen(eq(acct2), eq(device2), anyLong());
verify(accountsManager).updateDeviceLastSeen(eq(acct1.getIdentifier(IdentityType.ACI)), eq(device1), anyLong());
verify(accountsManager).updateDeviceLastSeen(eq(acct2.getIdentifier(IdentityType.ACI)), eq(device2), anyLong());
assertThat(device1.getLastSeen()).isEqualTo(today);
assertThat(device2.getLastSeen()).isEqualTo(today);
assertThat(updatedAcct1).isNotSameAs(acct1);
assertThat(updatedAcct2).isNotSameAs(acct2);
}
@Test
@ -142,7 +136,7 @@ class AccountAuthenticatorTest {
accountAuthenticator.updateLastSeen(oldAccount, device);
verify(accountsManager).updateDeviceLastSeen(eq(oldAccount), eq(device), anyLong());
verify(accountsManager).updateDeviceLastSeen(eq(oldAccount.getIdentifier(IdentityType.ACI)), eq(device), anyLong());
assertThat(device.getLastSeen()).isEqualTo(today);
}

View File

@ -76,10 +76,14 @@ class RegistrationLockVerificationManagerTest {
AccountsHelper.setupMockUpdate(accountsManager);
account = mock(Account.class);
when(account.getUuid()).thenReturn(UUID.randomUUID());
final UUID accountIdentifier = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.getNumber()).thenReturn("+18005551212");
when(account.getDevices()).thenReturn(List.of(device));
AccountsHelper.setupMockGet(accountsManager, account);
existingRegistrationLock = mock(StoredRegistrationLock.class);
when(account.getRegistrationLock()).thenReturn(existingRegistrationLock);
}

View File

@ -55,6 +55,7 @@ import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.auth.RedemptionRange;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiterConfig;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -62,6 +63,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ -110,14 +112,10 @@ public class BackupAuthManagerTest {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci);
when(accountsManager.update(any(), any()))
.thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final Consumer<Account> updater = invocation.getArgument(1);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(aci);
updater.accept(a);
return a;
});
AccountsHelper.setupMockGet(accountsManager, account);
AccountsHelper.setupMockUpdate(accountsManager);
final BackupAuthCredentialRequest messagesCredentialRequest = backupAuthTestUtil.getRequest(messagesBackupKey, aci);
final BackupAuthCredentialRequest mediaCredentialRequest = backupAuthTestUtil.getRequest(mediaBackupKey, aci);
@ -135,7 +133,7 @@ public class BackupAuthManagerTest {
void commitOnAnyBackupLevel(final BackupLevel backupLevel) {
final BackupAuthManager authManager = create();
final Account account = new MockAccountBuilder().backupLevel(backupLevel).build();
when(accountsManager.update(any(), any())).thenReturn(account);
when(accountsManager.update(any(Account.class), any())).thenReturn(account);
final ThrowableAssert.ThrowingCallable commit = () ->
authManager.commitBackupId(account,
@ -149,7 +147,7 @@ public class BackupAuthManagerTest {
void commitRequiresPrimary() {
final BackupAuthManager authManager = create();
final Account account = new MockAccountBuilder().build();
when(accountsManager.update(any(), any())).thenReturn(account);
when(accountsManager.update(any(Account.class), any())).thenReturn(account);
final ThrowableAssert.ThrowingCallable commit = () ->
authManager.commitBackupId(account,
@ -306,7 +304,7 @@ public class BackupAuthManagerTest {
.backupVoucher(null)
.build();
when(accountsManager.update(any(), any())).thenReturn(updated);
when(accountsManager.update(any(Account.class), any())).thenReturn(updated);
clock.pin(day2.plus(Duration.ofSeconds(1)));
assertThat(authManager
@ -316,7 +314,7 @@ public class BackupAuthManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Consumer<Account>> accountUpdater = ArgumentCaptor.forClass(
Consumer.class);
verify(accountsManager, times(1)).update(any(), accountUpdater.capture());
verify(accountsManager, times(1)).update(any(Account.class), accountUpdater.capture());
// If the account is not expired when we go to update it, we shouldn't wipe it out
final Account alreadyUpdated = mock(Account.class);
@ -341,11 +339,11 @@ public class BackupAuthManagerTest {
.mediaCredential(Optional.of(new byte[0]))
.build();
clock.pin(Instant.EPOCH.plus(Duration.ofDays(1)));
when(accountsManager.update(any(), any())).thenReturn(account);
when(accountsManager.update(any(Account.class), any())).thenReturn(account);
when(redeemedReceiptsManager.put(any(), eq(expirationTime.getEpochSecond()), eq(201L), eq(aci)))
.thenReturn(CompletableFuture.completedFuture(true));
authManager.redeemReceipt(account, receiptPresentation(201, expirationTime));
verify(accountsManager, times(1)).update(any(), any());
verify(accountsManager, times(1)).update(any(Account.class), any());
}
@Test
@ -375,13 +373,13 @@ public class BackupAuthManagerTest {
.build();
clock.pin(Instant.EPOCH.plus(Duration.ofDays(1)));
when(accountsManager.update(any(), any())).thenReturn(account);
when(accountsManager.update(any(Account.class), any())).thenReturn(account);
when(redeemedReceiptsManager.put(any(), eq(newExpirationTime.getEpochSecond()), eq(201L), eq(aci)))
.thenReturn(CompletableFuture.completedFuture(true));
authManager.redeemReceipt(account, receiptPresentation(201, newExpirationTime));
final ArgumentCaptor<Consumer<Account>> updaterCaptor = ArgumentCaptor.captor();
verify(accountsManager, times(1)).update(any(), updaterCaptor.capture());
verify(accountsManager, times(1)).update(any(Account.class), updaterCaptor.capture());
updaterCaptor.getValue().accept(account);
// Should select the voucher with the later expiration time
@ -430,7 +428,7 @@ public class BackupAuthManagerTest {
.build();
clock.pin(Instant.EPOCH.plus(Duration.ofDays(1)));
when(accountsManager.update(any(), any())).thenReturn(account);
when(accountsManager.update(any(Account.class), any())).thenReturn(account);
when(redeemedReceiptsManager.put(any(), eq(expirationTime.getEpochSecond()), eq(201L), eq(aci)))
.thenReturn(CompletableFuture.completedFuture(false));
@ -510,7 +508,7 @@ public class BackupAuthManagerTest {
.backupVoucher(backupVoucher)
.build();
when(accountsManager.update(any(), any())).thenReturn(account);
when(accountsManager.update(any(Account.class), any())).thenReturn(account);
final Optional<BackupAuthCredentialRequest> newMessagesCredential = switch (messageChange) {
case MATCH -> Optional.of(storedMessagesCredential);

View File

@ -293,7 +293,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE_3_PRIMARY, times(1)).setGcmId(eq("z000"));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), anyByte(), any());
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.VALID_UUID_3), anyByte(), any());
}
}
@ -322,7 +322,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE_3_PRIMARY, times(1)).setApnId(eq("first"));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), anyByte(), any());
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.VALID_UUID_3), anyByte(), any());
}
}
@ -406,7 +406,7 @@ class AccountControllerTest {
}
// make sure `update()` works
doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(), any());
doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(UUID.class), any());
try (final Response response =
builder.put(Entity.json(new EncryptedUsername(TestRandomUtil.nextBytes(payloadSize))))) {
@ -449,7 +449,7 @@ class AccountControllerTest {
}
// make sure `update()` works
doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(), any());
doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(UUID.class), any());
try (final Response delete = builder.delete()) {
assertEquals(expectedStatus, delete.getStatus());
@ -753,7 +753,7 @@ class AccountControllerTest {
@Test
void testDeleteUsername() {
when(accountsManager.clearUsernameHash(any()))
.thenAnswer(invocation -> invocation.getArgument(0));
.thenAnswer(invocation -> accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow());
try (final Response response = resources.getJerseyTest()
.target("/v1/accounts/username_hash/")
@ -762,7 +762,7 @@ class AccountControllerTest {
.delete()) {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager).clearUsernameHash(AuthHelper.VALID_ACCOUNT);
verify(accountsManager).clearUsernameHash(AuthHelper.VALID_UUID);
}
}
@ -887,7 +887,7 @@ class AccountControllerTest {
.delete()) {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager).delete(AuthHelper.VALID_ACCOUNT, AccountsManager.DeletionReason.USER_REQUEST);
verify(accountsManager).delete(AuthHelper.VALID_UUID, AccountsManager.DeletionReason.USER_REQUEST);
}
}
@ -903,7 +903,7 @@ class AccountControllerTest {
.delete()) {
assertThat(response.getStatus()).isEqualTo(500);
verify(accountsManager).delete(AuthHelper.VALID_ACCOUNT, AccountsManager.DeletionReason.USER_REQUEST);
verify(accountsManager).delete(AuthHelper.VALID_UUID, AccountsManager.DeletionReason.USER_REQUEST);
}
}
@ -1074,7 +1074,7 @@ class AccountControllerTest {
.put(Entity.json(new DeviceName(TestRandomUtil.nextBytes(64))))) {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), eq(Device.PRIMARY_ID), any());
verify(accountsManager).updateDevice(eq(AuthHelper.VALID_UUID_3), eq(Device.PRIMARY_ID), any());
}
}
@ -1088,7 +1088,7 @@ class AccountControllerTest {
.put(Entity.json(new DeviceName(TestRandomUtil.nextBytes(64))))) {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), eq(AuthHelper.VALID_DEVICE_3_LINKED_ID), any());
verify(accountsManager).updateDevice(eq(AuthHelper.VALID_UUID_3), eq(AuthHelper.VALID_DEVICE_3_LINKED_ID), any());
}
}
@ -1102,7 +1102,7 @@ class AccountControllerTest {
.put(Entity.json(new DeviceName(TestRandomUtil.nextBytes(64))))) {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager).updateDevice(eq(AuthHelper.VALID_ACCOUNT_3), eq(AuthHelper.VALID_DEVICE_3_LINKED_ID), any());
verify(accountsManager).updateDevice(eq(AuthHelper.VALID_UUID_3), eq(AuthHelper.VALID_DEVICE_3_LINKED_ID), any());
}
}

View File

@ -30,8 +30,6 @@ import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.EnumSet;
@ -141,7 +139,9 @@ class DeviceControllerTest {
when(account.getNextDeviceId()).thenReturn(NEXT_DEVICE_ID);
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID);
when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI);
when(account.getIdentifier(IdentityType.PNI)).thenReturn(AuthHelper.VALID_PNI);
when(account.getPrimaryDevice()).thenReturn(primaryDevice);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
@ -244,7 +244,7 @@ class DeviceControllerTest {
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final Account a = accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow();
final DeviceSpec deviceSpec = invocation.getArgument(1);
return new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey));
@ -268,7 +268,7 @@ class DeviceControllerTest {
assertThat(response.deviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class);
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture(), any());
verify(accountsManager).addDevice(eq(AuthHelper.VALID_UUID), deviceSpecCaptor.capture(), any());
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey);
@ -796,7 +796,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final Account a = accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow();
final DeviceSpec deviceSpec = invocation.getArgument(1);
return new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock, aciIdentityKey));
@ -886,7 +886,7 @@ class DeviceControllerTest {
@Test
void maxDevicesTest() throws LinkDeviceTokenAlreadyUsedException {
final List<Device> devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1)
.mapToObj(i -> mock(Device.class))
.mapToObj(_ -> mock(Device.class))
.toList();
when(account.getDevices()).thenReturn(devices);
@ -939,7 +939,7 @@ class DeviceControllerTest {
final byte deviceId = 2;
when(accountsManager.removeDevice(account, deviceId))
when(accountsManager.removeDevice(AuthHelper.VALID_UUID, deviceId))
.thenReturn(account);
try (final Response response = resources
@ -953,7 +953,7 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(accountsManager).removeDevice(account, deviceId);
verify(accountsManager).removeDevice(AuthHelper.VALID_UUID, deviceId);
}
}
@ -980,7 +980,7 @@ class DeviceControllerTest {
void removeDeviceBySelf() {
final byte deviceId = 2;
when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT_3, deviceId))
when(accountsManager.removeDevice(AuthHelper.VALID_UUID_3, deviceId))
.thenReturn(account);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3))
@ -997,7 +997,7 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(accountsManager).removeDevice(AuthHelper.VALID_ACCOUNT_3, deviceId);
verify(accountsManager).removeDevice(AuthHelper.VALID_UUID_3, deviceId);
}
}
@ -1173,7 +1173,6 @@ class DeviceControllerTest {
void recordTransferArchiveFailed() {
final byte deviceId = Device.PRIMARY_ID + 1;
final int registrationId = 123;
final Instant deviceCreated = Instant.now().truncatedTo(ChronoUnit.MILLIS);
final RemoteAttachmentError transferFailure = new RemoteAttachmentError(RemoteAttachmentError.ErrorType.CONTINUE_WITHOUT_UPLOAD);
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));

View File

@ -197,6 +197,7 @@ class ProfileControllerTest {
when(profileAccount.getIdentityKey(IdentityType.ACI)).thenReturn(ACCOUNT_TWO_IDENTITY_KEY);
when(profileAccount.getIdentityKey(IdentityType.PNI)).thenReturn(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY);
when(profileAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID_TWO);
when(profileAccount.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID_TWO);
when(profileAccount.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI_TWO);
when(profileAccount.getCurrentProfileVersion()).thenReturn(Optional.empty());
when(profileAccount.getUsernameHash()).thenReturn(Optional.of(USERNAME_HASH));
@ -207,6 +208,7 @@ class ProfileControllerTest {
capabilitiesAccount = mock(Account.class);
when(capabilitiesAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(capabilitiesAccount.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID);
when(capabilitiesAccount.getIdentityKey(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTITY_KEY);
when(capabilitiesAccount.getIdentityKey(IdentityType.PNI)).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY);
@ -1010,6 +1012,7 @@ class ProfileControllerTest {
void testGetProfileWithExpiringProfileKeyCredentialVersionNotFound() throws VerificationFailedException {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID);
when(account.getCurrentProfileVersion()).thenReturn(Optional.of(versionHex("version")));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
@ -1218,6 +1221,7 @@ class ProfileControllerTest {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(AuthHelper.VALID_UUID);
when(account.getCurrentProfileVersion()).thenReturn(Optional.of(version));
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY));
when(account.isIdentifiedBy(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(true);

View File

@ -134,7 +134,7 @@ class RegistrationControllerTest {
void setUp() {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(accountsManager.update(any(), any())).thenAnswer(invocation -> {
when(accountsManager.update(any(UUID.class), any())).thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final Consumer<Account> accountUpdater = invocation.getArgument(1);

View File

@ -27,7 +27,6 @@ import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
@ -77,6 +76,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.UsernameHashNotAvailableException;
import org.whispersystems.textsecuregcm.storage.UsernameReservationNotFoundException;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
@ -97,15 +97,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@Override
protected AccountsGrpcService createServiceBeforeEachTest() {
when(accountsManager.update(any(), any()))
.thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final Consumer<Account> updater = invocation.getArgument(1);
updater.accept(account);
return account;
});
AccountsHelper.setupMockUpdate(accountsManager);
final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getUsernameReserveLimiter()).thenReturn(rateLimiter);
@ -154,15 +146,10 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@Test
void deleteAccount() {
final Account account = mock(Account.class);
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI))
.thenReturn(Optional.of(account));
final DeleteAccountResponse ignored =
authenticatedServiceStub().deleteAccount(DeleteAccountRequest.newBuilder().build());
verify(accountsManager).delete(account, AccountsManager.DeletionReason.USER_REQUEST);
verify(accountsManager).delete(AUTHENTICATED_ACI, AccountsManager.DeletionReason.USER_REQUEST);
}
@Test
@ -206,7 +193,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
() -> authenticatedServiceStub().setRegistrationLock(SetRegistrationLockRequest.newBuilder()
.build()));
verify(accountsManager, never()).update(any(), any());
verify(accountsManager, never()).update(any(UUID.class), any());
}
@Test
@ -219,7 +206,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
.setRegistrationLock(ByteString.copyFrom(TestRandomUtil.nextBytes(32)))
.build()));
verify(accountsManager, never()).update(any(), any());
verify(accountsManager, never()).update(any(UUID.class), any());
}
@Test
@ -243,12 +230,13 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT,
() -> authenticatedServiceStub().clearRegistrationLock(ClearRegistrationLockRequest.newBuilder().build()));
verify(accountsManager, never()).update(any(), any());
verify(accountsManager, never()).update(any(UUID.class), any());
}
@Test
void reserveUsernameHash() throws UsernameHashNotAvailableException {
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI))
.thenReturn(Optional.of(account));
@ -259,7 +247,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
.thenAnswer(invocation -> {
final List<byte[]> usernameHashes = invocation.getArgument(1);
return new AccountsManager.UsernameReservation(invocation.getArgument(0), usernameHashes.getFirst());
return new AccountsManager.UsernameReservation(accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow(), usernameHashes.getFirst());
});
final ReserveUsernameHashResponse expectedResponse = ReserveUsernameHashResponse.newBuilder()
@ -363,7 +351,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI))
.thenReturn(Optional.of(account));
when(accountsManager.confirmReservedUsernameHash(account, usernameHash, usernameCiphertext))
when(accountsManager.confirmReservedUsernameHash(AUTHENTICATED_ACI, usernameHash, usernameCiphertext))
.thenAnswer(_ -> {
final Account updatedAccount = mock(Account.class);
@ -502,12 +490,12 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI))
.thenReturn(Optional.of(account));
when(accountsManager.clearUsernameHash(account)).thenReturn(account);
when(accountsManager.clearUsernameHash(AUTHENTICATED_ACI)).thenReturn(account);
assertDoesNotThrow(() ->
authenticatedServiceStub().deleteUsernameHash(DeleteUsernameHashRequest.newBuilder().build()));
verify(accountsManager).clearUsernameHash(account);
verify(accountsManager).clearUsernameHash(AUTHENTICATED_ACI);
}
@ParameterizedTest
@ -515,6 +503,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
void setUsernameLink(final boolean keepLink) {
final Account account = mock(Account.class);
final UUID oldHandle = UUID.randomUUID();
when(account.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);
when(account.getUsernameHash()).thenReturn(Optional.of(new byte[AccountController.USERNAME_HASH_LENGTH]));
when(account.getUsernameLinkHandle()).thenReturn(oldHandle);
@ -740,7 +729,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
final int updateMethodCalls = matchesCurrentZkCredentialKey ? 0 : 1;
verify(accountsManager, times(updateMethodCalls)).update(eq(account), any(Consumer.class));
verify(accountsManager, times(updateMethodCalls)).update(eq(AUTHENTICATED_ACI), any());
verify(account, times(updateMethodCalls)).setZkCredentialKey(aryEq(publicKey));
}
@ -763,7 +752,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
.setPublicKey(ByteString.copyFrom(publicKey))
.build()));
verify(accountsManager, never()).update(any(Account.class), any(Consumer.class));
verify(accountsManager, never()).update(any(UUID.class), any());
verify(account, never()).setZkCredentialKey(any());
}
}

View File

@ -25,7 +25,6 @@ import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.junit.jupiter.api.Test;
@ -53,6 +52,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, DevicesGrpc.DevicesBlockingStub> {
@ -66,6 +66,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@Override
protected DevicesGrpcService createServiceBeforeEachTest() {
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI))
.thenReturn(Optional.of(authenticatedAccount));
@ -73,26 +74,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
when(accountsManager.removeDevice(any(), anyByte()))
.thenReturn(authenticatedAccount);
when(accountsManager.update(any(), any()))
.thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final Consumer<Account> updater = invocation.getArgument(1);
updater.accept(account);
return account;
});
when(accountsManager.updateDevice(any(), anyByte(), any()))
.thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final Device device = account.getDevice(invocation.getArgument(1)).orElseThrow();
final Consumer<Device> updater = invocation.getArgument(2);
updater.accept(device);
return account;
});
AccountsHelper.setupMockUpdate(accountsManager);
return new DevicesGrpcService(accountsManager);
}
@ -154,7 +136,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
.setId(deviceId)
.build());
verify(accountsManager).removeDevice(authenticatedAccount, deviceId);
verify(accountsManager).removeDevice(AUTHENTICATED_ACI, deviceId);
}
@Test
@ -176,7 +158,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
.setId(deviceId)
.build());
verify(accountsManager).removeDevice(authenticatedAccount, deviceId);
verify(accountsManager).removeDevice(AUTHENTICATED_ACI, deviceId);
}
@Test

View File

@ -29,8 +29,6 @@ import com.google.common.net.InetAddresses;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import java.nio.charset.StandardCharsets;
@ -41,7 +39,6 @@ import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
@ -71,7 +68,6 @@ import org.signal.chat.profile.SetProfileRequest.AvatarChange;
import org.signal.chat.profile.SetProfileResponse;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.ServerPublicParams;
@ -181,6 +177,7 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
when(dynamicConfiguration.getPaymentsConfiguration()).thenReturn(dynamicPaymentsConfiguration);
when(account.getUuid()).thenReturn(AUTHENTICATED_ACI);
when(account.getIdentifier(org.whispersystems.textsecuregcm.identity.IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);
when(account.getNumber()).thenReturn(phoneNumber);
when(account.getBadges()).thenReturn(Collections.emptyList());
@ -188,7 +185,7 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
when(profile.avatar()).thenReturn("");
when(accountsManager.getByAccountIdentifier(any())).thenReturn(Optional.of(account));
when(accountsManager.update(any(), any())).thenReturn(null);
when(accountsManager.update(any(UUID.class), any())).thenReturn(null);
when(profilesManager.get(any(), any())).thenReturn(Optional.of(profile));

View File

@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -150,12 +151,17 @@ class PushNotificationManagerTest {
@Test
void testSendNotificationFcm() {
final UUID accountIdentifier = UUID.randomUUID();
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
AccountsHelper.setupMockGet(accountsManager, account);
final PushNotification pushNotification = new PushNotification(
"token", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, account, device, true);
@ -166,7 +172,7 @@ class PushNotificationManagerTest {
verify(fcmSender).sendNotification(pushNotification);
verifyNoInteractions(apnSender);
verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(accountsManager, never()).updateDevice(eq(accountIdentifier), eq(Device.PRIMARY_ID), any());
verify(device, never()).setGcmId(any());
verifyNoInteractions(pushNotificationScheduler);
}
@ -221,6 +227,7 @@ class PushNotificationManagerTest {
when(device.getGcmId()).thenReturn("token");
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
when(account.getUuid()).thenReturn(aci);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(aci);
when(accountsManager.getByAccountIdentifier(aci)).thenReturn(Optional.of(account));
final PushNotification pushNotification = new PushNotification(
@ -231,7 +238,7 @@ class PushNotificationManagerTest {
pushNotificationManager.sendNotification(pushNotification);
verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(accountsManager).updateDevice(eq(aci), eq(Device.PRIMARY_ID), any());
verify(device).setGcmId(null);
verifyNoInteractions(apnSender);
verifyNoInteractions(pushNotificationScheduler);
@ -246,6 +253,7 @@ class PushNotificationManagerTest {
when(device.getApnId()).thenReturn("apns-token");
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
when(account.getUuid()).thenReturn(aci);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(aci);
when(accountsManager.getByAccountIdentifier(aci)).thenReturn(Optional.of(account));
final PushNotification pushNotification = new PushNotification(
@ -260,7 +268,7 @@ class PushNotificationManagerTest {
pushNotificationManager.sendNotification(pushNotification);
verifyNoInteractions(fcmSender);
verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(accountsManager).updateDevice(eq(aci), eq(Device.PRIMARY_ID), any());
verify(device).setApnId(null);
verify(pushNotificationScheduler).cancelScheduledNotifications(account, device);
}
@ -291,7 +299,7 @@ class PushNotificationManagerTest {
pushNotificationManager.sendNotification(pushNotification);
verifyNoInteractions(fcmSender);
verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(accountsManager, never()).updateDevice(eq(aci), eq(Device.PRIMARY_ID), any());
verify(device, never()).setApnId(any());
verify(pushNotificationScheduler, never()).cancelScheduledNotifications(account, device);
}

View File

@ -463,7 +463,7 @@ public class AccountCreationDeletionIntegrationTest {
assertTrue(accountsManager.getByAccountIdentifier(aci).isPresent());
accountsManager.delete(account, AccountsManager.DeletionReason.ADMIN_DELETED);
accountsManager.delete(account.getIdentifier(IdentityType.ACI), AccountsManager.DeletionReason.ADMIN_DELETED);
assertFalse(accountsManager.getByAccountIdentifier(aci).isPresent());
assertFalse(keysManager.getEcSignedPreKey(account.getUuid(), Device.PRIMARY_ID).join().isPresent());
@ -471,7 +471,8 @@ public class AccountCreationDeletionIntegrationTest {
assertFalse(keysManager.getLastResort(account.getUuid(), Device.PRIMARY_ID).join().isPresent());
assertFalse(keysManager.getLastResort(account.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent());
verify(disconnectionRequestManager).requestDisconnection(account);
verify(disconnectionRequestManager).requestDisconnection(argThat(disconnectedAccount ->
disconnectedAccount.getIdentifier(IdentityType.ACI).equals(account.getIdentifier(IdentityType.ACI))));
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")

View File

@ -169,7 +169,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate();
accountsManager.changeNumber(account,
accountsManager.changeNumber(originalUuid,
secondNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)),
@ -197,7 +197,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate();
accountsManager.changeNumber(account,
accountsManager.changeNumber(originalUuid,
originalNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)),
@ -235,7 +235,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Map<Byte, KEMSignedPreKey> kemSignedPreKeys = Map.of(Device.PRIMARY_ID, rotatedKemSignedPreKey);
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, kemSignedPreKeys, registrationIds);
final Account updatedAccount = accountsManager.changeNumber(originalUuid, secondNumber, pniIdentityKey, preKeys, kemSignedPreKeys, registrationIds);
final UUID secondPni = updatedAccount.getPhoneNumberIdentifier();
assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -270,7 +270,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final ECKeyPair originalIdentityKeyPair = ECKeyPair.generate();
final ECKeyPair secondIdentityKeyPair = ECKeyPair.generate();
account = accountsManager.changeNumber(account,
account = accountsManager.changeNumber(originalUuid,
secondNumber,
new IdentityKey(secondIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, secondIdentityKeyPair)),
@ -279,7 +279,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID secondPni = account.getPhoneNumberIdentifier();
accountsManager.changeNumber(account,
accountsManager.changeNumber(originalUuid,
originalNumber,
new IdentityKey(originalIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(3, originalIdentityKeyPair)),
@ -315,7 +315,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID existingAccountUuid = existingAccount.getUuid();
accountsManager.changeNumber(account,
accountsManager.changeNumber(originalUuid,
secondNumber,
new IdentityKey(secondIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, secondIdentityKeyPair)),
@ -337,7 +337,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni));
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni));
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(),
accountsManager.changeNumber(originalUuid,
originalNumber,
new IdentityKey(originalIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, originalIdentityKeyPair)),
@ -364,7 +364,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID existingAccountUuid = existingAccount.getUuid();
final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate();
final Account changedNumberAccount = accountsManager.changeNumber(account,
final Account changedNumberAccount = accountsManager.changeNumber(originalUuid,
secondNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)),
@ -383,7 +383,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final ECKeyPair reRegisteredPniIdentityKeyPair = ECKeyPair.generate();
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount,
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount.getIdentifier(IdentityType.ACI),
secondNumber,
new IdentityKey(reRegisteredPniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, reRegisteredPniIdentityKeyPair)),

View File

@ -159,7 +159,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
null),
null).getIdentifier(IdentityType.ACI),
a -> {
a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
a.removeDevice(Device.PRIMARY_ID);
@ -226,18 +226,10 @@ class AccountsManagerConcurrentModificationIntegrationTest {
}
private CompletableFuture<?> modifyAccount(final UUID uuid, final Consumer<Account> accountMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.getByAccountIdentifier(uuid).orElseThrow();
accountsManager.update(account, accountMutation);
}, mutationExecutor);
return CompletableFuture.runAsync(() -> accountsManager.update(uuid, accountMutation), mutationExecutor);
}
private CompletableFuture<?> modifyDevice(final UUID uuid, final byte deviceId, final Consumer<Device> deviceMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.getByAccountIdentifier(uuid).orElseThrow();
accountsManager.updateDevice(account, deviceId, deviceMutation);
}, mutationExecutor);
return CompletableFuture.runAsync(() -> accountsManager.updateDevice(uuid, deviceId, deviceMutation), mutationExecutor);
}
}

View File

@ -603,6 +603,7 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
@ -614,40 +615,21 @@ class AccountsManagerTest {
final IdentityKey identityKey = new IdentityKey(ECKeyPair.generate().getPublicKey());
account = accountsManager.update(account, a -> a.setIdentityKey(identityKey));
account = accountsManager.update(uuid, a -> a.setIdentityKey(identityKey));
assertEquals(1, account.getVersion());
assertEquals(identityKey, account.getIdentityKey(IdentityType.ACI));
verify(accounts, times(1)).getByAccountIdentifier(uuid);
verify(accounts, times(2)).getByAccountIdentifier(uuid);
verify(accounts, times(2)).update(any());
verifyNoMoreInteractions(accounts);
}
@Test
void testUpdate_dynamoOptimisticLockingFailureDuringCreate() throws AccountAlreadyExistsException {
UUID uuid = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account));
when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class);
accountsManager.update(account, a -> {
});
verify(accounts, times(1)).update(account);
verifyNoMoreInteractions(accounts);
}
@Test
void testUpdateDevice() {
final UUID uuid = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH])));
addRetrievableAccount(account);
assertTrue(account.getDevices().isEmpty());
@ -661,14 +643,14 @@ class AccountsManagerTest {
@SuppressWarnings("unchecked") Consumer<Device> deviceUpdater = mock(Consumer.class);
@SuppressWarnings("unchecked") Consumer<Device> unknownDeviceUpdater = mock(Consumer.class);
account = accountsManager.updateDevice(account, deviceId, deviceUpdater);
account = accountsManager.updateDevice(account, deviceId, d -> d.setName("deviceName".getBytes(StandardCharsets.UTF_8)));
account = accountsManager.updateDevice(uuid, deviceId, deviceUpdater);
account = accountsManager.updateDevice(uuid, deviceId, d -> d.setName("deviceName".getBytes(StandardCharsets.UTF_8)));
assertArrayEquals("deviceName".getBytes(StandardCharsets.UTF_8), account.getDevice(deviceId).orElseThrow().getName());
verify(deviceUpdater, times(1)).accept(any(Device.class));
accountsManager.updateDevice(account, account.getNextDeviceId(), unknownDeviceUpdater);
accountsManager.updateDevice(uuid, account.getNextDeviceId(), unknownDeviceUpdater);
verify(unknownDeviceUpdater, never()).accept(any(Device.class));
}
@ -689,7 +671,7 @@ class AccountsManagerTest {
assertTrue(account.getDevice(linkedDevice.getId()).isPresent());
account = accountsManager.removeDevice(account, linkedDevice.getId());
account = accountsManager.removeDevice(account.getIdentifier(IdentityType.ACI), linkedDevice.getId());
assertFalse(account.getDevice(linkedDevice.getId()).isPresent());
verify(messagesManager, times(2)).clear(account.getUuid(), linkedDevice.getId());
@ -708,7 +690,8 @@ class AccountsManagerTest {
when(keysManager.deleteSingleUsePreKeys(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
assertThrows(IllegalArgumentException.class, () -> accountsManager.removeDevice(account, Device.PRIMARY_ID));
assertThrows(IllegalArgumentException.class,
() -> accountsManager.removeDevice(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID));
assertDoesNotThrow(account::getPrimaryDevice);
verify(messagesManager, never()).clear(any(), anyByte());
@ -876,7 +859,7 @@ class AccountsManagerTest {
CLOCK.pin(CLOCK.instant().plusSeconds(60));
final Pair<Account, Device> updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec(
final Pair<Account, Device> updatedAccountAndDevice = accountsManager.addDevice(aci, new DeviceSpec(
deviceNameCiphertext,
password,
signalAgent,
@ -922,10 +905,12 @@ class AccountsManagerTest {
@MethodSource
void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) {
final Account account = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
final Device device = generateTestDevice(initialLastSeen);
account.addDevice(device);
accountsManager.updateDeviceLastSeen(account, device, updatedLastSeen);
accountsManager.updateDeviceLastSeen(account.getIdentifier(IdentityType.ACI), device, updatedLastSeen);
assertEquals(expectUpdate ? updatedLastSeen : initialLastSeen, device.getLastSeen());
verify(accounts, expectUpdate ? times(1) : never()).update(account);
@ -957,7 +942,8 @@ class AccountsManagerTest {
final KEMSignedPreKey kemLastResortPreKey = KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair);
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
account = accountsManager.changeNumber(account,
addRetrievableAccount(account);
account = accountsManager.changeNumber(uuid,
targetNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, ecSignedPreKey),
@ -985,9 +971,11 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount(originalNumber, UUID.randomUUID(), phoneNumberIdentifier,
List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
phoneNumberIdentifiersByE164.put(originalNumber, account.getPhoneNumberIdentifier());
phoneNumberIdentifiersByE164.put(newNumber, account.getPhoneNumberIdentifier());
account = accountsManager.changeNumber(account,
account = accountsManager.changeNumber(account.getIdentifier(IdentityType.ACI),
newNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)),
@ -1016,7 +1004,9 @@ class AccountsManagerTest {
final KEMSignedPreKey kemLastResoryPreKey = KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair);
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
account = accountsManager.changeNumber(account,
addRetrievableAccount(account);
account = accountsManager.changeNumber(uuid,
targetNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, ecSignedPreKey),
@ -1064,8 +1054,9 @@ class AccountsManagerTest {
DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
final Account updatedAccount = accountsManager.changeNumber(
account, targetNumber, new IdentityKey(ECKeyPair.generate().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds);
uuid, targetNumber, new IdentityKey(ECKeyPair.generate().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds);
assertEquals(targetNumber, updatedAccount.getNumber());
@ -1102,11 +1093,12 @@ class AccountsManagerTest {
final List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
assertThrows(MismatchedDevicesException.class,
() -> accountsManager.changeNumber(
account, targetNumber, new IdentityKey(ECKeyPair.generate().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds));
uuid, targetNumber, new IdentityKey(ECKeyPair.generate().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds));
verifyNoInteractions(accounts);
verify(accounts, never()).changeNumber(any(), any(), any(), any(), any());
verifyNoInteractions(keysManager);
}
@ -1117,8 +1109,9 @@ class AccountsManagerTest {
final UUID uuid = UUID.randomUUID();
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setNumber(targetNumber, UUID.randomUUID())));
assertThrows(AssertionError.class, () -> accountsManager.update(uuid, a -> a.setNumber(targetNumber, UUID.randomUUID())));
}
@Test
@ -1128,7 +1121,7 @@ class AccountsManagerTest {
final List<byte[]> usernameHashes = List.of(TestRandomUtil.nextBytes(32), TestRandomUtil.nextBytes(32));
final UsernameReservation result = accountsManager.reserveUsernameHash(account, usernameHashes);
final UsernameReservation result = accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes);
assertArrayEquals(usernameHashes.getFirst(), result.reservedUsernameHash());
verify(accounts, times(1)).reserveUsernameHash(eq(account), any(), eq(Duration.ofMinutes(5)));
}
@ -1142,7 +1135,7 @@ class AccountsManagerTest {
final List<byte[]> usernameHashes = List.of(TestRandomUtil.nextBytes(32), oldUsernameHash, TestRandomUtil.nextBytes(32));
final UsernameReservation result = accountsManager.reserveUsernameHash(account, usernameHashes);
final UsernameReservation result = accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes);
assertArrayEquals(oldUsernameHash, result.reservedUsernameHash());
verify(accounts, never()).reserveUsernameHash(any(), any(), any());
}
@ -1158,7 +1151,7 @@ class AccountsManagerTest {
.doNothing()
.when(accounts).reserveUsernameHash(any(), any(), any());
final UsernameReservation result = accountsManager.reserveUsernameHash(account, usernameHashes);
final UsernameReservation result = accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes);
assertArrayEquals(usernameHashes.getFirst(), result.reservedUsernameHash());
verify(accounts, times(2)).reserveUsernameHash(eq(account), any(), eq(Duration.ofMinutes(5)));
}
@ -1166,21 +1159,23 @@ class AccountsManagerTest {
@Test
void testReserveUsernameHashAsyncNotAvailable() throws UsernameHashNotAvailableException {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByAccountIdentifierAsync(account.getUuid())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
addRetrievableAccount(account);
doThrow(new UsernameHashNotAvailableException())
.when(accounts).reserveUsernameHash(any(), any(), any());
assertThrows(UsernameHashNotAvailableException.class, () ->
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1, USERNAME_HASH_2)));
accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1, USERNAME_HASH_2)));
}
@Test
void testConfirmReservedUsernameHash() throws UsernameHashNotAvailableException, UsernameReservationNotFoundException {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
setReservationHash(account, USERNAME_HASH_1);
accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
verify(accounts).confirmUsernameHash(eq(account), eq(USERNAME_HASH_1), eq(ENCRYPTED_USERNAME_1));
}
@ -1194,62 +1189,72 @@ class AccountsManagerTest {
.doNothing()
.when(accounts).confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
verify(accounts, times(2)).confirmUsernameHash(eq(account), eq(USERNAME_HASH_1), eq(ENCRYPTED_USERNAME_1));
}
@Test
void testConfirmReservedHashNameMismatch() {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
setReservationHash(account, USERNAME_HASH_1);
assertThrows(UsernameReservationNotFoundException.class,
() -> accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_2, ENCRYPTED_USERNAME_2));
() -> accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_2, ENCRYPTED_USERNAME_2));
}
@Test
void testConfirmReservedLapsed() throws UsernameHashNotAvailableException {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
// hash was reserved, but the reservation lapsed and another account took it
setReservationHash(account, USERNAME_HASH_1);
doThrow(new UsernameHashNotAvailableException())
.when(accounts).confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
assertThrows(UsernameHashNotAvailableException.class,
() -> accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1));
() -> accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1));
assertTrue(account.getUsernameHash().isEmpty());
}
@Test
void testConfirmReservedRetry() throws UsernameHashNotAvailableException, UsernameReservationNotFoundException {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
account.setUsernameHash(USERNAME_HASH_1);
// reserved username already set, should be treated as a replay
accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
verifyNoInteractions(accounts);
accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
verify(accounts, never()).confirmUsernameHash(any(), any(), any());
}
@Test
void testConfirmReservedUsernameHashWithNoReservation() throws UsernameHashNotAvailableException {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(),
new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
assertThrows(UsernameReservationNotFoundException.class,
() -> accountsManager.confirmReservedUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1));
() -> accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1));
verify(accounts, never()).confirmUsernameHash(any(), any(), any());
}
@Test
void testClearUsernameHash() {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
account.setUsernameHash(USERNAME_HASH_1);
accountsManager.clearUsernameHash(account);
accountsManager.clearUsernameHash(account.getIdentifier(IdentityType.ACI));
verify(accounts).clearUsernameHash(eq(account));
}
@Test
void testSetUsernameViaUpdate() {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
addRetrievableAccount(account);
assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setUsernameHash(USERNAME_HASH_1)));
assertThrows(AssertionError.class, () ->
accountsManager.update(account.getIdentifier(IdentityType.ACI), a -> a.setUsernameHash(USERNAME_HASH_1)));
}
@Test
@ -1440,4 +1445,12 @@ class AccountsManagerTest {
new MismatchedDevices(Set.of(deviceId), Set.of((byte) (extraDeviceId)), Collections.emptySet())))
);
}
private void addRetrievableAccount(final Account account) {
when(accounts.getByAccountIdentifier(account.getIdentifier(IdentityType.ACI)))
.thenReturn(Optional.of(account));
when(accounts.getByAccountIdentifierAsync(account.getIdentifier(IdentityType.ACI)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
}
}

View File

@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -176,7 +177,7 @@ class AccountsManagerUsernameIntegrationTest {
}
assertThrows(UsernameHashNotAvailableException.class,
() -> accountsManager.reserveUsernameHash(account, usernameHashes));
() -> accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes));
assertThat(accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getUsernameHash()).isEmpty();
}
@ -201,7 +202,7 @@ class AccountsManagerUsernameIntegrationTest {
usernameHashes.add(TestRandomUtil.nextBytes(32));
final byte[] username = accountsManager
.reserveUsernameHash(account, usernameHashes)
.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), usernameHashes)
.reservedUsernameHash();
assertArrayEquals(username, availableHash);
@ -214,14 +215,14 @@ class AccountsManagerUsernameIntegrationTest {
// reserve
AccountsManager.UsernameReservation reservation =
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1));
accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1));
assertArrayEquals(reservation.account().getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1);
assertThat(accountsManager.getByUsernameHash(reservation.reservedUsernameHash()).join()).isEmpty();
// confirm
account = accountsManager.confirmReservedUsernameHash(
reservation.account(),
reservation.account().getIdentifier(IdentityType.ACI),
reservation.reservedUsernameHash(),
ENCRYPTED_USERNAME_1);
assertArrayEquals(account.getUsernameHash().orElseThrow(), USERNAME_HASH_1);
@ -232,7 +233,7 @@ class AccountsManagerUsernameIntegrationTest {
.isEqualTo(account.getUuid());
// clear
account = accountsManager.clearUsernameHash(account);
account = accountsManager.clearUsernameHash(account.getIdentifier(IdentityType.ACI));
assertThat(accountsManager.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
assertThat(accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getUsernameHash()).isEmpty();
}
@ -243,16 +244,16 @@ class AccountsManagerUsernameIntegrationTest {
Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
AccountsManager.UsernameReservation reservation =
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1));
accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1));
// confirm
account = accountsManager.confirmReservedUsernameHash(
reservation.account(),
reservation.account().getIdentifier(IdentityType.ACI),
reservation.reservedUsernameHash(),
ENCRYPTED_USERNAME_1);
// clear
account = accountsManager.clearUsernameHash(account);
account = accountsManager.clearUsernameHash(account.getIdentifier(IdentityType.ACI));
assertThat(accountsManager.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
assertThat(accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getUsernameHash()).isEmpty();
@ -260,7 +261,7 @@ class AccountsManagerUsernameIntegrationTest {
Account account2 = AccountsHelper.createAccount(accountsManager, "+18005552222");
assertThrows(UsernameHashNotAvailableException.class,
() -> accountsManager.reserveUsernameHash(account2, List.of(USERNAME_HASH_1)),
() -> accountsManager.reserveUsernameHash(account2.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1)),
"account2 should not be able to reserve a held hash");
}
@ -270,7 +271,7 @@ class AccountsManagerUsernameIntegrationTest {
final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
AccountsManager.UsernameReservation reservation1 =
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1));
accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1));
long past = Instant.now().minus(Duration.ofMinutes(1)).getEpochSecond();
// force expiration
@ -286,12 +287,12 @@ class AccountsManagerUsernameIntegrationTest {
Account account2 = AccountsHelper.createAccount(accountsManager, "+18005552222");
final AccountsManager.UsernameReservation reservation2 =
accountsManager.reserveUsernameHash(account2, List.of(USERNAME_HASH_1));
accountsManager.reserveUsernameHash(account2.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1));
assertArrayEquals(reservation2.reservedUsernameHash(), USERNAME_HASH_1);
assertThrows(UsernameHashNotAvailableException.class,
() -> accountsManager.confirmReservedUsernameHash(reservation1.account(), USERNAME_HASH_1, ENCRYPTED_USERNAME_1));
account2 = accountsManager.confirmReservedUsernameHash(reservation2.account(), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
() -> accountsManager.confirmReservedUsernameHash(reservation1.account().getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1));
account2 = accountsManager.confirmReservedUsernameHash(reservation2.account().getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
assertEquals(accountsManager.getByUsernameHash(USERNAME_HASH_1).join().orElseThrow().getUuid(), account2.getUuid());
assertArrayEquals(account2.getUsernameHash().orElseThrow(), USERNAME_HASH_1);
}
@ -303,13 +304,13 @@ class AccountsManagerUsernameIntegrationTest {
// Set username hash
final AccountsManager.UsernameReservation reservation1 =
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1));
accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1));
account = accountsManager.confirmReservedUsernameHash(reservation1.account(), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
account = accountsManager.confirmReservedUsernameHash(reservation1.account().getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
// Reserve another hash on the same account
final AccountsManager.UsernameReservation reservation2 =
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_2));
accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_2));
account = reservation2.account();
@ -318,12 +319,12 @@ class AccountsManagerUsernameIntegrationTest {
assertArrayEquals(account.getEncryptedUsername().orElseThrow(), ENCRYPTED_USERNAME_1);
// Clear the set username hash but not the reserved one
account = accountsManager.clearUsernameHash(account);
account = accountsManager.clearUsernameHash(account.getIdentifier(IdentityType.ACI));
assertThat(account.getReservedUsernameHash()).isPresent();
assertThat(account.getUsernameHash()).isEmpty();
// Confirm second reservation
account = accountsManager.confirmReservedUsernameHash(account, reservation2.reservedUsernameHash(), ENCRYPTED_USERNAME_2);
account = accountsManager.confirmReservedUsernameHash(account.getIdentifier(IdentityType.ACI), reservation2.reservedUsernameHash(), ENCRYPTED_USERNAME_2);
assertArrayEquals(account.getUsernameHash().orElseThrow(), USERNAME_HASH_2);
assertArrayEquals(account.getEncryptedUsername().orElseThrow(), ENCRYPTED_USERNAME_2);
}
@ -333,8 +334,8 @@ class AccountsManagerUsernameIntegrationTest {
throws InterruptedException, UsernameHashNotAvailableException, UsernameReservationNotFoundException {
Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
final AccountsManager.UsernameReservation reservation1 =
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1));
account = accountsManager.confirmReservedUsernameHash(reservation1.account(), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
accountsManager.reserveUsernameHash(account.getIdentifier(IdentityType.ACI), List.of(USERNAME_HASH_1));
account = accountsManager.confirmReservedUsernameHash(reservation1.account().getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
// "reclaim" the account by re-registering
Account reclaimed = AccountsHelper.createAccount(accountsManager, "+18005551111");
@ -346,20 +347,29 @@ class AccountsManagerUsernameIntegrationTest {
assertThat(accountsManager.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
// confirm it again
accountsManager.confirmReservedUsernameHash(reclaimed, USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
accountsManager.confirmReservedUsernameHash(reclaimed.getIdentifier(IdentityType.ACI), USERNAME_HASH_1, ENCRYPTED_USERNAME_1);
assertThat(accountsManager.getByUsernameHash(USERNAME_HASH_1).join()).isPresent();
}
@Test
public void testUsernameLinks() throws InterruptedException, AccountAlreadyExistsException {
final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
public void testUsernameLinks()
throws InterruptedException, AccountAlreadyExistsException, UsernameHashNotAvailableException, UsernameReservationNotFoundException {
final UUID accountIdentifier;
{
final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
accounts.create(account, Collections.emptyList());
account.setUsernameHash(TestRandomUtil.nextBytes(16));
accounts.create(account, Collections.emptyList());
accountIdentifier = account.getIdentifier(IdentityType.ACI);
}
final AccountsManager.UsernameReservation reservation =
accountsManager.reserveUsernameHash(accountIdentifier, List.of(USERNAME_HASH_1));
accountsManager.confirmReservedUsernameHash(accountIdentifier, reservation.reservedUsernameHash(), ENCRYPTED_USERNAME_1);
final UUID linkHandle = UUID.randomUUID();
final byte[] encryptedUsername = TestRandomUtil.nextBytes(32);
accountsManager.update(account, a -> a.setUsernameLinkDetails(linkHandle, encryptedUsername));
accountsManager.update(accountIdentifier, account -> account.setUsernameLinkDetails(linkHandle, encryptedUsername));
final Optional<Account> maybeAccount = accountsManager.getByUsernameLinkHandle(linkHandle).join();
assertTrue(maybeAccount.isPresent());
@ -367,17 +377,17 @@ class AccountsManagerUsernameIntegrationTest {
assertArrayEquals(encryptedUsername, maybeAccount.get().getEncryptedUsername().get());
// making some unrelated change and updating account to check that username link data is still there
final Optional<Account> accountToChange = accountsManager.getByAccountIdentifier(account.getUuid());
final Optional<Account> accountToChange = accountsManager.getByAccountIdentifier(accountIdentifier);
assertTrue(accountToChange.isPresent());
accountsManager.update(accountToChange.get(), a -> a.setDiscoverableByPhoneNumber(!a.isDiscoverableByPhoneNumber()));
accountsManager.update(accountToChange.get().getIdentifier(IdentityType.ACI), a -> a.setDiscoverableByPhoneNumber(!a.isDiscoverableByPhoneNumber()));
final Optional<Account> accountAfterChange = accountsManager.getByUsernameLinkHandle(linkHandle).join();
assertTrue(accountAfterChange.isPresent());
assertTrue(accountAfterChange.get().getEncryptedUsername().isPresent());
assertArrayEquals(encryptedUsername, accountAfterChange.get().getEncryptedUsername().get());
// now deleting the link
final Optional<Account> accountToDeleteLink = accountsManager.getByAccountIdentifier(account.getUuid());
accountsManager.update(accountToDeleteLink.orElseThrow(), a -> a.setUsernameLinkDetails(null, null));
final Optional<Account> accountToDeleteLink = accountsManager.getByAccountIdentifier(accountIdentifier);
accountsManager.update(accountToDeleteLink.orElseThrow().getIdentifier(IdentityType.ACI), a -> a.setUsernameLinkDetails(null, null));
assertTrue(accounts.getByUsernameLinkHandle(linkHandle).join().isEmpty());
}
}

View File

@ -184,7 +184,7 @@ public class AddRemoveDeviceIntegrationTest {
assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size());
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
@ -234,7 +234,7 @@ public class AddRemoveDeviceIntegrationTest {
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI));
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
@ -255,7 +255,7 @@ public class AddRemoveDeviceIntegrationTest {
.size());
assertThrows(LinkDeviceTokenAlreadyUsedException.class,
() -> accountsManager.addDevice(account, new DeviceSpec(
() -> accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
@ -289,7 +289,7 @@ public class AddRemoveDeviceIntegrationTest {
assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size());
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
@ -307,7 +307,7 @@ public class AddRemoveDeviceIntegrationTest {
final byte addedDeviceId = updatedAccountAndDevice.second().getId();
final Account updatedAccount = accountsManager.removeDevice(updatedAccountAndDevice.first(), addedDeviceId);
final Account updatedAccount = accountsManager.removeDevice(updatedAccountAndDevice.first().getIdentifier(IdentityType.ACI), addedDeviceId);
assertEquals(1, updatedAccount.getDevices().size());
@ -340,7 +340,7 @@ public class AddRemoveDeviceIntegrationTest {
final UUID aci = account.getIdentifier(IdentityType.ACI);
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
@ -362,7 +362,7 @@ public class AddRemoveDeviceIntegrationTest {
.thenReturn(CompletableFuture.failedFuture(new RuntimeException("OH NO")));
assertThrows(RuntimeException.class,
() -> accountsManager.removeDevice(updatedAccountAndDevice.first(), addedDeviceId));
() -> accountsManager.removeDevice(updatedAccountAndDevice.first().getIdentifier(IdentityType.ACI), addedDeviceId));
final Account retrievedAccount = accountsManager.getByAccountIdentifierAsync(aci).join().orElseThrow();
@ -410,7 +410,7 @@ public class AddRemoveDeviceIntegrationTest {
assertEquals(Optional.empty(), displacedFuture.join());
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
@ -451,7 +451,7 @@ public class AddRemoveDeviceIntegrationTest {
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
@ -520,7 +520,7 @@ public class AddRemoveDeviceIntegrationTest {
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID());
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
accountsManager.addDevice(account, new DeviceSpec(
accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
@ -563,7 +563,7 @@ public class AddRemoveDeviceIntegrationTest {
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
clock.pin(Instant.ofEpochMilli(0));
accountsManager.addDevice(account, new DeviceSpec(
accountsManager.addDevice(account.getIdentifier(IdentityType.ACI), new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",

View File

@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
@ -56,7 +57,7 @@ public class ChangeNumberManagerTest {
updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
when(accountsManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final Account account = accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow();
final String number = invocation.getArgument(1, String.class);
final UUID uuid = account.getIdentifier(IdentityType.ACI);
@ -99,11 +100,14 @@ public class ChangeNumberManagerTest {
final UUID accountIdentifier = UUID.randomUUID();
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(any())).thenReturn(false);
when(account.isIdentifiedBy(new AciServiceIdentifier(accountIdentifier))).thenReturn(true);
AccountsHelper.setupMockGet(accountsManager, account);
changeNumberManager.changeNumber(account, targetNumber, pniIdentityKey, ecSignedPreKeys, kemLastResortPreKeys, Collections.emptyList(), Collections.emptyMap(), null);
verify(accountsManager).changeNumber(account, targetNumber, pniIdentityKey, ecSignedPreKeys, kemLastResortPreKeys, Collections.emptyMap());
verify(accountsManager).changeNumber(accountIdentifier, targetNumber, pniIdentityKey, ecSignedPreKeys, kemLastResortPreKeys, Collections.emptyMap());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any());
}
@ -154,6 +158,8 @@ public class ChangeNumberManagerTest {
final IncomingMessage incomingMessage =
new IncomingMessage(1, linkedDeviceId, linkedDeviceRegistrationId, new byte[] { 1 });
AccountsHelper.setupMockGet(accountsManager, account);
changeNumberManager.changeNumber(account,
targetNumber,
pniIdentityKey,
@ -163,7 +169,7 @@ public class ChangeNumberManagerTest {
registrationIds,
null);
verify(accountsManager).changeNumber(account,
verify(accountsManager).changeNumber(aci,
targetNumber,
pniIdentityKey,
ecSignedPreKeys,

View File

@ -31,7 +31,6 @@ import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;

View File

@ -115,7 +115,7 @@ class MessagePersisterTest {
when(accountsManager.getByAccountIdentifierAsync(DESTINATION_ACCOUNT_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount)));
when(accountsManager.removeDevice(any(), anyByte()))
.thenAnswer(invocation -> invocation.getArgument(0));
.thenAnswer(invocation -> accountsManager.getByAccountIdentifier(invocation.getArgument(0)).orElseThrow());
when(destinationAccount.getUuid()).thenReturn(DESTINATION_ACCOUNT_UUID);
when(destinationAccount.getIdentifier(IdentityType.ACI)).thenReturn(DESTINATION_ACCOUNT_UUID);
@ -465,7 +465,7 @@ class MessagePersisterTest {
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()).block());
verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID);
verify(accountsManager, exactly()).removeDevice(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID);
}
@Test
@ -552,7 +552,7 @@ class MessagePersisterTest {
when(destinationAccount.getDevices()).thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination));
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenThrow(new RuntimeException());
when(accountsManager.removeDevice(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID)).thenThrow(new RuntimeException());
assertThrows(RuntimeException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()).block());
}

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.tests.util;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockingDetails;
import static org.mockito.Mockito.when;
@ -18,16 +17,21 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import org.junit.platform.commons.util.StringUtils;
import org.mockito.MockingDetails;
import org.mockito.stubbing.Stubbing;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceAttributes;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -64,8 +68,8 @@ public class AccountsHelper {
/**
* Sets up stubbing for:
* <ul>
* <li>{@link AccountsManager#update(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateDevice(Account, byte, Consumer)}</li>
* <li>{@link AccountsManager#update(UUID, Consumer)}</li>
* <li>{@link AccountsManager#updateDevice(UUID, byte, Consumer)}</li>
* </ul>
*
* with multiple calls to the {@link Consumer<Account>}. This simulates retries from {@link org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException}.
@ -80,67 +84,127 @@ public class AccountsHelper {
*/
@SuppressWarnings("unchecked")
public static void setupMockUpdateWithRetries(final AccountsManager mockAccountsManager, final int retryCount) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
when(mockAccountsManager.update(any(UUID.class), any())).thenAnswer(invocation -> {
final UUID accountIdentifier = invocation.getArgument(0, UUID.class);
final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow();
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
invocation.getArgument(1, Consumer.class).accept(account);
}
return account;
});
when(mockAccountsManager.update(any(Account.class), any())).thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
for (int i = 0; i < retryCount; i++) {
invocation.getArgument(1, Consumer.class).accept(account);
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
when(mockAccountsManager.updateDevice(any(UUID.class), anyByte(), any())).thenAnswer(answer -> {
final UUID accountIdentifier = answer.getArgument(0, UUID.class);
final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow();
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return copyAndMarkStale(account);
return account;
});
}
@SuppressWarnings("unchecked")
private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
answer.getArgument(1, Consumer.class).accept(account);
when(mockAccountsManager.update(any(UUID.class), any())).thenAnswer(invocation -> {
final UUID accountIdentifier = invocation.getArgument(0, UUID.class);
final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow();
invocation.getArgument(1, Consumer.class).accept(account);
return account;
});
when(mockAccountsManager.update(any(Account.class), any())).thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
invocation.getArgument(1, Consumer.class).accept(account);
return markStale ? copyAndMarkStale(account) : account;
});
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(invocation -> {
final UUID accountIdentifier = invocation.getArgument(0, UUID.class);
final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow();
return markStale ? copyAndMarkStale(account) : account;
final byte deviceId = invocation.getArgument(1, Byte.class);
account.getDevice(deviceId).ifPresent(invocation.getArgument(2, Consumer.class));
return account;
});
when(mockAccountsManager.updateDeviceLastSeen(any(), any(), anyLong())).thenAnswer(answer -> {
answer.getArgument(1, Device.class).setLastSeen(answer.getArgument(2, Long.class));
return mockAccountsManager.update(answer.getArgument(0, Account.class), account -> {});
when(mockAccountsManager.updateDeviceLastSeen(any(), any(), anyLong())).thenAnswer(invocation -> {
final UUID accountIdentifier = invocation.getArgument(0, UUID.class);
final Account account = mockAccountsManager.getByAccountIdentifier(accountIdentifier).orElseThrow();
final Device device = account.getDevice(invocation.getArgument(1, Device.class).getId()).orElseThrow();
device.setLastSeen(invocation.getArgument(2, Long.class));
return mockAccountsManager.update(accountIdentifier, _ -> {});
});
}
public static void setupMockGet(final AccountsManager mockAccountsManager, final Set<Account> mockAccounts) {
when(mockAccountsManager.getByAccountIdentifier(any(UUID.class))).thenAnswer(answer -> {
public static void setupMockGet(final AccountsManager mockAccountsManager, final Account account) {
if (account.getUuid() != null || account.getIdentifier(IdentityType.ACI) != null) {
final UUID accountIdentifier =
Objects.requireNonNullElseGet(account.getIdentifier(IdentityType.ACI), account::getUuid);
final UUID uuid = answer.getArgument(0, UUID.class);
when(mockAccountsManager.getByAccountIdentifier(accountIdentifier))
.thenReturn(Optional.of(account));
return mockAccounts.stream()
.filter(account -> uuid.equals(account.getUuid()))
.findFirst()
.map(account -> {
try {
return copyAndMarkStale(account);
} catch (final Exception e) {
throw new RuntimeException(e);
}
});
});
when(mockAccountsManager.getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(mockAccountsManager.getByServiceIdentifier(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(Optional.of(account));
when(mockAccountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
}
if (account.getPhoneNumberIdentifier() != null || account.getIdentifier(IdentityType.PNI) != null) {
final UUID phoneNumberIdentifier =
Objects.requireNonNullElseGet(account.getIdentifier(IdentityType.PNI), account::getPhoneNumberIdentifier);
when(mockAccountsManager.getByPhoneNumberIdentifier(phoneNumberIdentifier))
.thenReturn(Optional.of(account));
when(mockAccountsManager.getByPhoneNumberIdentifierAsync(phoneNumberIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(mockAccountsManager.getByServiceIdentifier(new PniServiceIdentifier(phoneNumberIdentifier)))
.thenReturn(Optional.of(account));
when(mockAccountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(phoneNumberIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
}
if (StringUtils.isNotBlank(account.getNumber())) {
when(mockAccountsManager.getByE164(account.getNumber())).thenReturn(Optional.of(account));
}
account.getUsernameHash().ifPresent(usernameHash -> when(mockAccountsManager.getByUsernameHash(usernameHash))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account))));
if (account.getUsernameLinkHandle() != null) {
when(mockAccountsManager.getByUsernameLinkHandle(account.getUsernameLinkHandle()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
}
}
private static Account copyAndMarkStale(Account account) throws IOException {
@ -192,10 +256,6 @@ public class AccountsHelper {
return updatedAccount;
}
public static Account eqUuid(Account value) {
return argThat(other -> other.getUuid().equals(value.getUuid()));
}
public static Account createAccount(final AccountsManager accountsManager, final String e164)
throws InterruptedException {

View File

@ -166,7 +166,7 @@ public class AuthHelper {
when(VALID_ACCOUNT_TWO.getUuid()).thenReturn(VALID_UUID_TWO);
when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO);
when(VALID_ACCOUNT_TWO.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_TWO);
when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO);
when(VALID_ACCOUNT_TWO.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI_TWO);
when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER);
when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID);
when(UNDISCOVERABLE_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(UNDISCOVERABLE_PNI);

View File

@ -16,12 +16,14 @@ import static org.mockito.Mockito.when;
import java.time.Clock;
import java.time.Instant;
import java.time.ZoneId;
import java.util.UUID;
import java.util.stream.Stream;
import net.sourceforge.argparse4j.inf.Namespace;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import reactor.core.publisher.Flux;
@ -64,10 +66,15 @@ class RemoveExpiredAccountsCommandTest {
final RemoveExpiredAccountsCommand removeExpiredAccountsCommand =
new TestRemoveExpiredAccountsCommand(clock, accountsManager, isDryRun);
final UUID activeAccountIdentifier = UUID.randomUUID();
final UUID expiredAccountIdentifier = UUID.randomUUID();
final Account activeAccount = mock(Account.class);
when(activeAccount.getIdentifier(IdentityType.ACI)).thenReturn(activeAccountIdentifier);
when(activeAccount.getLastSeen()).thenReturn(clock.instant().toEpochMilli());
final Account expiredAccount = mock(Account.class);
when(expiredAccount.getIdentifier(IdentityType.ACI)).thenReturn(expiredAccountIdentifier);
when(expiredAccount.getLastSeen())
.thenReturn(clock.instant().minus(RemoveExpiredAccountsCommand.MAX_IDLE_DURATION).minusMillis(1).toEpochMilli());
@ -76,8 +83,8 @@ class RemoveExpiredAccountsCommandTest {
if (isDryRun) {
verify(accountsManager, never()).delete(any(), any());
} else {
verify(accountsManager).delete(expiredAccount, AccountsManager.DeletionReason.EXPIRED);
verify(accountsManager, never()).delete(eq(activeAccount), any());
verify(accountsManager).delete(expiredAccountIdentifier, AccountsManager.DeletionReason.EXPIRED);
verify(accountsManager, never()).delete(eq(activeAccountIdentifier), any());
}
}

View File

@ -23,6 +23,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.stream.IntStream;
@ -31,6 +32,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.TestClock;
@ -76,7 +78,10 @@ class RemoveExpiredUsernameHoldsCommandTest {
final RemoveExpiredUsernameHoldsCommand removeExpiredUsernameHoldsCommand =
new TestRemoveExpiredUsernameHoldsCommand(clock, accountsManager, isDryRun);
final UUID hasHoldsAccountIdentifier = UUID.randomUUID();
final Account hasHolds = mock(Account.class);
when(hasHolds.getIdentifier(IdentityType.ACI)).thenReturn(hasHoldsAccountIdentifier);
final List<Account.UsernameHold> originalHolds = List.of(
// expired
new Account.UsernameHold(TestRandomUtil.nextBytes(32), Instant.EPOCH.getEpochSecond()),
@ -92,7 +97,7 @@ class RemoveExpiredUsernameHoldsCommandTest {
verifyNoInteractions(accountsManager);
} else {
ArgumentCaptor<Consumer<Account>> updaterCaptor = ArgumentCaptor.forClass(Consumer.class);
verify(accountsManager, times(1)).update(eq(hasHolds), updaterCaptor.capture());
verify(accountsManager, times(1)).update(eq(hasHoldsAccountIdentifier), updaterCaptor.capture());
final Consumer<Account> consumer = updaterCaptor.getValue();
consumer.accept(hasHolds);
verify(hasHolds, times(1)).setUsernameHolds(argThat(holds ->

View File

@ -124,7 +124,7 @@ class UnlinkDevicesWithIdlePrimaryCommandTest {
accountWithIdlePrimaryAndLinkedDevice));
if (!isDryRun) {
verify(accountsManager).removeDevice(accountWithIdlePrimaryAndLinkedDevice, linkedDeviceId);
verify(accountsManager).removeDevice(accountWithIdlePrimaryAndLinkedDevice.getIdentifier(IdentityType.ACI), linkedDeviceId);
}
verifyNoMoreInteractions(accountsManager);