Revert "Mirror message reads/acknowledgements via FoundationDbMessageStream"

This reverts commit 4e52317e26.
This commit is contained in:
Jon Chambers 2026-06-24 12:58:20 -04:00
parent 4e52317e26
commit 28aefe0ebe
8 changed files with 241 additions and 347 deletions

View File

@ -1,155 +0,0 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import java.util.function.BiConsumer;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.adapter.JdkFlowAdapter;
import reactor.core.publisher.BaseSubscriber;
/// A temporary message stream that can mirror message acknowledgements (deletion requests) to FoundationDB
public class AcknowledgementMirroringMessageStream implements MessageStream {
private final RedisDynamoDbMessageStream redisDynamoDbMessageStream;
private final FoundationDbMessageStream foundationDbMessageStream;
private final Set<UUID> messagesPendingAcknowledgement = new HashSet<>();
private static final int HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_THRESHOLD = 1024;
@VisibleForTesting
static final int MAX_PENDING_ACKNOWLEDGEMENTS = 8192;
@VisibleForTesting
static final int FOUNDATIONDB_REQUEST_SIZE = 100;
private static final Counter HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_COUNTER =
Metrics.counter(MetricsUtil.name(AcknowledgementMirroringMessageStream.class, "highPendingAcknowledgementCount"));
private static final Counter MAX_PENDING_ACKNOWLEDGEMENT_LIMIT_BREACHED_COUNTER =
Metrics.counter(MetricsUtil.name(AcknowledgementMirroringMessageStream.class, "maxPendingAcknowledgementLimitBreached"));
private static class FoundationDbSubscriber extends BaseSubscriber<MessageStreamEntry.Envelope> {
private final BiConsumer<UUID, Long> mirroredMessageHandler;
private long mirroredEntriesDelivered = 0;
private long requested = 0;
private FoundationDbSubscriber(final BiConsumer<UUID, Long> mirroredMessageHandler) {
this.mirroredMessageHandler = mirroredMessageHandler;
}
public void handleRedisDynamoDbMessageStreamEntry(final MessageStreamEntry messageStreamEntry) {
final boolean isMirroredMessage = switch (messageStreamEntry) {
case MessageStreamEntry.Envelope envelopeEntry ->
UUIDUtil.fromByteString(envelopeEntry.message().getServerGuid()).version() == 8;
case MessageStreamEntry.QueueEmpty _ -> false;
};
if (!isMirroredMessage) {
return;
}
final boolean requestMore;
synchronized (this) {
requestMore = ++mirroredEntriesDelivered > requested;
if (requestMore) {
requested += FOUNDATIONDB_REQUEST_SIZE;
}
}
if (requestMore) {
request(FOUNDATIONDB_REQUEST_SIZE);
}
}
@Override
protected void hookOnNext(final MessageStreamEntry.Envelope envelope) {
mirroredMessageHandler.accept(
UUIDUtil.fromByteString(envelope.message().getServerGuid()), envelope.message().getServerTimestamp());
}
}
public AcknowledgementMirroringMessageStream(final RedisDynamoDbMessageStream redisDynamoDbMessageStream,
final FoundationDbMessageStream foundationDbMessageStream) {
this.redisDynamoDbMessageStream = redisDynamoDbMessageStream;
this.foundationDbMessageStream = foundationDbMessageStream;
}
@Override
public Flow.Publisher<MessageStreamEntry> getMessages() {
final FoundationDbSubscriber subscriber = new FoundationDbSubscriber(this::handleMirroredMessageAcknowledged);
JdkFlowAdapter.flowPublisherToFlux(foundationDbMessageStream.getMessages())
.filter(messageStreamEntry -> messageStreamEntry instanceof MessageStreamEntry.Envelope)
.cast(MessageStreamEntry.Envelope.class)
.subscribe(subscriber);
return JdkFlowAdapter.publisherToFlowPublisher(
JdkFlowAdapter.flowPublisherToFlux(redisDynamoDbMessageStream.getMessages())
// Mirror demand and termination signals to the FoundationDB message stream
.doOnNext(subscriber::handleRedisDynamoDbMessageStreamEntry)
.doOnComplete(subscriber::dispose));
}
@Override
public CompletableFuture<Void> acknowledgeMessage(final UUID messageGuid, final long serverTimestamp) {
// All messages stored in FoundationDB use version 8 UUIDs; if a message has a version 4 UUID, then it only exists
// in Redis/DynamoDB
if (messageGuid.version() == 8) {
handleMirroredMessageAcknowledged(messageGuid, serverTimestamp);
}
return redisDynamoDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp);
}
private void handleMirroredMessageAcknowledged(final UUID messageGuid, final long serverTimestamp) {
// Before we can acknowledge and delete the message in FoundationDB, two things need to happen:
//
// 1. The client needs to acknowledge the message on the Redis/DDB stream.
// 2. The message needs to pass through the FoundationDB message stream pipeline because it needs to be registered
// in the acknowledged message buffer machinery before it can be deleted.
//
// We can't guarantee the order in which these two things happen, so we use a synchronized set to coordinate and
// determine when both have happened.
synchronized (messagesPendingAcknowledgement) {
if (messagesPendingAcknowledgement.remove(messageGuid)) {
// One of the two streams had already acknowledged this message, and this is the second acknowledgement. Now we
// can pass the acknowledgement along to FoundationDB.
foundationDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp);
} else {
// Either this message came from FoundationDB and got auto-acknowledged before getting explicitly acknowledged
// on the Redis/DynamoDB stream or the message got explicitly acknowledged on the Redis/DynamoDB stream before
// passing through the FoundationDB stream. Either way, wait for both streams to have done their part.
if (messagesPendingAcknowledgement.add(messageGuid)) {
if (messagesPendingAcknowledgement.size() == HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_THRESHOLD) {
HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_COUNTER.increment();
}
if (messagesPendingAcknowledgement.size() > MAX_PENDING_ACKNOWLEDGEMENTS) {
MAX_PENDING_ACKNOWLEDGEMENT_LIMIT_BREACHED_COUNTER.increment();
throw new IllegalStateException("Too many pending acknowledgements");
}
}
}
}
}
}

View File

@ -0,0 +1,66 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore;
/// A temporary message stream that can mirror message acknowledgements (deletion requests) to FoundationDB
public class DeletionMirroringRedisDynamoDbMessageStream implements MessageStream {
private final RedisDynamoDbMessageStream redisDynamoDbMessageStream;
private final FoundationDbMessageStore foundationDbMessageStore;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final AciServiceIdentifier accountIdentifier;
private final byte deviceId;
private static final Logger logger = LoggerFactory.getLogger(DeletionMirroringRedisDynamoDbMessageStream.class);
public DeletionMirroringRedisDynamoDbMessageStream(final RedisDynamoDbMessageStream redisDynamoDbMessageStream,
final FoundationDbMessageStore foundationDbMessageStore,
final ExperimentEnrollmentManager experimentEnrollmentManager,
final UUID accountIdentifier,
final byte deviceId) {
this.redisDynamoDbMessageStream = redisDynamoDbMessageStream;
this.foundationDbMessageStore = foundationDbMessageStore;
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.accountIdentifier = new AciServiceIdentifier(accountIdentifier);
this.deviceId = deviceId;
}
@Override
public Flow.Publisher<MessageStreamEntry> getMessages() {
return redisDynamoDbMessageStream.getMessages();
}
@Override
public CompletableFuture<Void> acknowledgeMessage(final UUID messageGuid, final long serverTimestamp) {
// All messages stored in FoundationDB use version 8 UUIDs; if a message has a version 4 UUID, then it only exists
// in Redis/DynamoDB
if (messageGuid.version() == 8 &&
experimentEnrollmentManager.isEnrolled(accountIdentifier.uuid(), MessagesManager.MIRROR_DELETIONS_EXPERIMENT_NAME)) {
foundationDbMessageStore.delete(accountIdentifier, deviceId, messageGuid)
.whenComplete((_, throwable) -> {
if (throwable != null) {
logger.warn("Failed to delete message {}/{}/{} from FoundationDb", accountIdentifier.uuid(), deviceId,
messageGuid, throwable);
}
});
}
return redisDynamoDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp);
}
}

View File

@ -38,7 +38,6 @@ import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager;
import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore;
import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.observability.micrometer.Micrometer;
import reactor.core.publisher.Flux;
@ -280,15 +279,12 @@ public class MessagesManager {
}
public MessageStream getMessages(final UUID destinationUuid, final Device destinationDevice) {
final RedisDynamoDbMessageStream redisDynamoDbMessageStream =
new RedisDynamoDbMessageStream(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager,
destinationUuid, destinationDevice);
return experimentEnrollmentManager.isEnrolled(destinationUuid, MIRROR_DELETIONS_EXPERIMENT_NAME)
? new AcknowledgementMirroringMessageStream(
redisDynamoDbMessageStream,
foundationDbMessageStore.getMessages(new AciServiceIdentifier(destinationUuid), destinationDevice.getId()))
: redisDynamoDbMessageStream;
return new DeletionMirroringRedisDynamoDbMessageStream(
new RedisDynamoDbMessageStream(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, destinationUuid, destinationDevice),
foundationDbMessageStore,
experimentEnrollmentManager,
destinationUuid,
destinationDevice.getId());
}
Publisher<Envelope> getMessagesForDevice(final UUID destinationUuid, final Device destinationDevice) {

View File

@ -9,9 +9,6 @@ import com.apple.foundationdb.tuple.Versionstamp;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.hash.Hashing;
import io.dropwizard.util.DataSize;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
@ -25,9 +22,14 @@ import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.Util;
@ -69,6 +71,12 @@ public class FoundationDbMessageStore {
private static final Timer INSERT_MESSAGE_BATCH_TIMER =
Metrics.timer(MetricsUtil.name(FoundationDbMessageStore.class, "insertMessageBatchTimer"));
private static final Counter DELETE_MESSAGE_COUNTER =
Metrics.counter(MetricsUtil.name(FoundationDbMessageStore.class, "deleteMessage"));
private static final Timer DELETE_MESSAGE_TIMER =
Metrics.timer(MetricsUtil.name(FoundationDbMessageStore.class, "deleteMessageTimer"));
/// Result of inserting a message for a particular device
///
/// @param versionstamp the versionstamp of the transaction in which this device's message was inserted, empty
@ -316,6 +324,27 @@ public class FoundationDbMessageStore {
});
}
// Note that this method is intended only for initial migration support; in general, callers should clear messages
// by acknowledging messages via a `FoundationDbMessageStream`.
public CompletableFuture<Void> delete(final AciServiceIdentifier aci, final byte deviceId, final UUID messageGuid) {
return delete(aci, deviceId, versionstampUUIDCipher.decryptVersionstamp(messageGuid, aci.uuid(), deviceId));
}
private CompletableFuture<Void> delete(final AciServiceIdentifier aci, final byte deviceId, final Versionstamp versionstamp) {
final Timer.Sample sample = Timer.start();
final byte[] messageKey = getDeviceQueueSubspace(aci, deviceId).pack(Tuple.from(versionstamp));
return databasesByEpoch[getConfigurationEpoch(versionstamp)][getShardId(versionstamp)].runAsync(transaction -> {
transaction.clear(messageKey);
return CompletableFuture.completedFuture(null);
})
.thenRun(() -> {
sample.stop(DELETE_MESSAGE_TIMER);
DELETE_MESSAGE_COUNTER.increment();
});
}
public void clearAll(final AciServiceIdentifier aci) {
doForAllDatabasesWithMessages(aci, database -> database.run(transaction -> {
transaction.clear(getAccountSubspace(aci).range());
@ -338,14 +367,14 @@ public class FoundationDbMessageStore {
.forEach(action);
}
public FoundationDbMessageStream getMessages(final AciServiceIdentifier aci, final byte deviceId) {
return getMessages(aci, deviceId, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN,
public MessageStream getMessages(final AciServiceIdentifier aci, final Device destinationDevice) {
return getMessages(aci, destinationDevice, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN,
FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, Util.NOOP);
}
@VisibleForTesting
FoundationDbMessageStream getMessages(final AciServiceIdentifier aci,
final byte deviceId,
MessageStream getMessages(final AciServiceIdentifier aci,
final Device destinationDevice,
final int maxMessagesPerScan,
final int maxUnacknowledgedMessages,
final Runnable doAfterCleanup) {
@ -359,10 +388,10 @@ public class FoundationDbMessageStore {
: null;
}
return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, deviceId),
return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, destinationDevice.getId()),
getMessagesAvailableWatchKey(aci),
databasesForQueueByEpoch,
new MessageGuidCodec(aci.uuid(), deviceId, versionstampUUIDCipher),
new MessageGuidCodec(aci.uuid(), destinationDevice.getId(), versionstampUUIDCipher),
maxMessagesPerScan,
maxUnacknowledgedMessages,
doAfterCleanup);

View File

@ -51,15 +51,8 @@ public class FoundationDbMessageStream implements MessageStream {
private final Map<Database, AcknowledgedMessageBuffer> acknowledgedMessageBuffersByDatabase;
private final Counter messageReadCounter =
Metrics.counter(name(FoundationDbMessageStream.class, "messagesRead"));
private final Counter messageAcknowledgedCounter =
Metrics.counter(name(FoundationDbMessageStream.class, "messagesAcknowledged"));
private final Counter staleEphemeralMessagesCounter =
Metrics.counter(name(FoundationDbMessageStream.class, "staleEphemeralMessages"));
private final Counter staleEphemeralMessagesCounter = Metrics.counter(
name(FoundationDbMessageStream.class, "staleEphemeralMessages"));
static final int DEFAULT_MAX_MESSAGES_PER_SCAN = 1024;
@VisibleForTesting
static final int DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES = 16_384;
@ -123,11 +116,6 @@ public class FoundationDbMessageStream implements MessageStream {
sink.next(messageStreamEntry);
})
.map(fdbMessageStreamEntry -> fdbMessageStreamEntry.toMessageStreamEntry(messageGuidCodec))
.doOnNext(messageStreamEntry -> {
if (messageStreamEntry instanceof MessageStreamEntry.Envelope) {
messageReadCounter.increment();
}
})
.doFinally(_ -> flushAllAcknowledgedMessages().thenRun(doAfterCleanup));
}
@ -237,8 +225,6 @@ public class FoundationDbMessageStream implements MessageStream {
final Versionstamp versionstamp = messageGuidCodec.decodeMessageGuid(messageGuid);
getAcknowledgedMessageBuffer(versionstamp).acknowledgeMessage(versionstamp);
messageAcknowledgedCounter.increment();
return CompletableFuture.completedFuture(null);
}

View File

@ -1,146 +0,0 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.adapter.JdkFlowAdapter;
import reactor.core.publisher.Flux;
class AcknowledgementMirroringMessageStreamTest {
private RedisDynamoDbMessageStream redisDynamoDbMessageStream;
private FoundationDbMessageStream foundationDbMessageStream;
private AcknowledgementMirroringMessageStream mirroringMessageStream;
private static long serialMessageTimestamp = 0;
private enum UUIDVersion {
V4,
V8
}
@BeforeEach
void setUp() {
redisDynamoDbMessageStream = mock(RedisDynamoDbMessageStream.class);
when(redisDynamoDbMessageStream.acknowledgeMessage(any(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(null));
foundationDbMessageStream = mock(FoundationDbMessageStream.class);
when(foundationDbMessageStream.acknowledgeMessage(any(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(null));
mirroringMessageStream = new AcknowledgementMirroringMessageStream(
redisDynamoDbMessageStream,
foundationDbMessageStream);
}
@ParameterizedTest
@EnumSource(UUIDVersion.class)
void acknowledgeMessages(final UUIDVersion uuidVersion) {
final List<MessageStreamEntry> entries = new ArrayList<>();
{
for (int i = 0; i < AcknowledgementMirroringMessageStream.FOUNDATIONDB_REQUEST_SIZE + 1; i++) {
entries.add(generateEnvelopeEntry(uuidVersion));
}
entries.add(new MessageStreamEntry.QueueEmpty());
}
when(redisDynamoDbMessageStream.getMessages())
.thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(entries)));
when(foundationDbMessageStream.getMessages())
.thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(
uuidVersion == UUIDVersion.V8 ? entries : Collections.emptyList())));
JdkFlowAdapter.flowPublisherToFlux(mirroringMessageStream.getMessages())
.doOnNext(entry -> {
if (entry instanceof MessageStreamEntry.Envelope(MessageProtos.Envelope message)) {
mirroringMessageStream.acknowledgeMessage(UUIDUtil.fromByteString(message.getServerGuid()),
message.getServerTimestamp());
}
})
.takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty)
.then()
.block();
if (uuidVersion == UUIDVersion.V8) {
entries.stream()
.filter(streamEntry -> streamEntry instanceof MessageStreamEntry.Envelope)
.map(streamEntry -> ((MessageStreamEntry.Envelope) streamEntry).message())
.forEach(envelope -> verify(foundationDbMessageStream).acknowledgeMessage(
UUIDUtil.fromByteString(envelope.getServerGuid()), envelope.getServerTimestamp()));
} else {
verify(foundationDbMessageStream, never()).acknowledgeMessage(any(), anyLong());
}
}
@Test
void acknowledgeMessagesOverflow() {
final List<MessageStreamEntry> entries = new ArrayList<>();
{
for (int i = 0; i < AcknowledgementMirroringMessageStream.MAX_PENDING_ACKNOWLEDGEMENTS + 1; i++) {
entries.add(generateEnvelopeEntry(UUIDVersion.V8));
}
entries.add(new MessageStreamEntry.QueueEmpty());
}
when(redisDynamoDbMessageStream.getMessages())
.thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(entries)));
when(foundationDbMessageStream.getMessages())
.thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.empty()));
assertThrows(IllegalStateException.class,
() -> JdkFlowAdapter.flowPublisherToFlux(mirroringMessageStream.getMessages())
.doOnNext(entry -> {
if (entry instanceof MessageStreamEntry.Envelope(MessageProtos.Envelope message)) {
mirroringMessageStream.acknowledgeMessage(UUIDUtil.fromByteString(message.getServerGuid()),
message.getServerTimestamp());
}
})
.takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty)
.then()
.block());
verify(foundationDbMessageStream, never()).acknowledgeMessage(any(), anyLong());
}
private static MessageStreamEntry.Envelope generateEnvelopeEntry(final UUIDVersion uuidVersion) {
final UUID messageGuid = switch (uuidVersion) {
case V4 -> UUID.randomUUID();
case V8 -> MessageGuidUtil.generateRandomV8UUID();
};
return new MessageStreamEntry.Envelope(MessageProtos.Envelope.newBuilder()
.setServerGuid(UUIDUtil.toByteString(messageGuid))
.setServerTimestamp(serialMessageTimestamp++)
.build());
}
}

View File

@ -0,0 +1,68 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore;
class DeletionMirroringRedisDynamoDbMessageStreamTest {
private FoundationDbMessageStore foundationDbMessageStore;
private ExperimentEnrollmentManager experimentEnrollmentManager;
private DeletionMirroringRedisDynamoDbMessageStream deletionMirroringRedisDynamoDbMessageStream;
private static final AciServiceIdentifier ACCOUNT_IDENTIFIER = new AciServiceIdentifier(UUID.randomUUID());
private static final byte DEVICE_ID = Device.PRIMARY_ID;
@BeforeEach
void setUp() {
foundationDbMessageStore = mock(FoundationDbMessageStore.class);
experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
deletionMirroringRedisDynamoDbMessageStream = new DeletionMirroringRedisDynamoDbMessageStream(
mock(RedisDynamoDbMessageStream.class),
foundationDbMessageStore,
experimentEnrollmentManager,
ACCOUNT_IDENTIFIER.uuid(),
DEVICE_ID);
}
@ParameterizedTest
@MethodSource
void acknowledgeMessage(final boolean enrolled, final UUID messageGuid, final boolean expectFoundationDbDeletion) {
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), eq(MessagesManager.MIRROR_DELETIONS_EXPERIMENT_NAME)))
.thenReturn(enrolled);
deletionMirroringRedisDynamoDbMessageStream.acknowledgeMessage(messageGuid, System.currentTimeMillis());
verify(foundationDbMessageStore, times(expectFoundationDbDeletion ? 1 : 0))
.delete(ACCOUNT_IDENTIFIER, DEVICE_ID, messageGuid);
}
private static List<Arguments> acknowledgeMessage() {
return List.of(
Arguments.argumentSet("Not enrolled, v4 UUID", false, UUID.randomUUID(), false),
Arguments.argumentSet("Not enrolled, v8 UUID", false, MessageGuidUtil.generateRandomV8UUID(), false),
Arguments.argumentSet("Enrolled, v4 UUID", true, UUID.randomUUID(), false),
Arguments.argumentSet("Enrolled, v8 UUID", true, MessageGuidUtil.generateRandomV8UUID(), true)
);
}
}

View File

@ -452,6 +452,36 @@ class FoundationDbMessageStoreTest {
Map.of(generateRandomAciForShard(0), Collections.emptyMap())));
}
@Test
void delete() {
final AciServiceIdentifier deletedMessageAci = new AciServiceIdentifier(UUID.randomUUID());
final byte deletedMessageDeviceId = Device.PRIMARY_ID;
final MessageProtos.Envelope deletedMessage = generateRandomMessage(false);
final AciServiceIdentifier retainedMessageAci = new AciServiceIdentifier(UUID.randomUUID());
final byte retainedMessageDeviceId = Device.PRIMARY_ID;
final MessageProtos.Envelope retainedMessage = generateRandomMessage(false);
final UUID deletedMessageGuid =
foundationDbMessageStore.insert(deletedMessageAci, Map.of(deletedMessageDeviceId, deletedMessage)).join()
.get(deletedMessageDeviceId).messageGuid().orElseThrow();
foundationDbMessageStore.insert(retainedMessageAci, Map.of(retainedMessageDeviceId, retainedMessage)).join();
assertArrayEquals(deletedMessage.toByteArray(),
getItemsInDeviceQueue(deletedMessageAci, deletedMessageDeviceId).getFirst().getValue());
assertArrayEquals(retainedMessage.toByteArray(),
getItemsInDeviceQueue(retainedMessageAci, retainedMessageDeviceId).getFirst().getValue());
foundationDbMessageStore.delete(deletedMessageAci, deletedMessageDeviceId, deletedMessageGuid).join();
assertTrue(getItemsInDeviceQueue(deletedMessageAci, deletedMessageDeviceId).isEmpty());
assertArrayEquals(retainedMessage.toByteArray(),
getItemsInDeviceQueue(retainedMessageAci, retainedMessageDeviceId).getFirst().getValue());
}
@Test
void clearAllForAccount() {
final AciServiceIdentifier deletedAccountIdentifier = new AciServiceIdentifier(UUID.randomUUID());
@ -523,9 +553,10 @@ class FoundationDbMessageStoreTest {
.orElseThrow())
.toList();
final MessageStream messageStream =
foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, batchSize, numMessages, Util.NOOP);
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, batchSize, numMessages,
Util.NOOP);
final List<MessageStreamEntry> retrievedEntries = new ArrayList<>();
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()))
.recordWith(() -> retrievedEntries)
@ -561,6 +592,9 @@ class FoundationDbMessageStoreTest {
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = Device.PRIMARY_ID;
final Device device = new Device();
device.setId(deviceId);
final int messagesPerBatch = 8;
final AtomicLong serialTimestamp = new AtomicLong();
@ -574,7 +608,7 @@ class FoundationDbMessageStoreTest {
final List<MessageProtos.Envelope> liveDefaultEpochMessages = new ArrayList<>();
final List<MessageProtos.Envelope> liveFutureEpochMessages = new ArrayList<>();
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, deviceId);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device);
final List<MessageStreamEntry> retrievedEntries = new ArrayList<>();
final CountDownLatch queueEmptyLatch = new CountDownLatch(1);
@ -655,7 +689,9 @@ class FoundationDbMessageStoreTest {
}
});
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID);
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device);
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()))
.expectNext(new MessageStreamEntry.Envelope(message1
.toBuilder()
@ -699,7 +735,9 @@ class FoundationDbMessageStoreTest {
}
});
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID);
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device);
// Initially only request a single message, then give the go ahead to the publisher. This verifies that we get the
// queue empty signal when we consume the initial batch of messages even though the publisher keeps publishing in
// the meantime.
@ -730,8 +768,10 @@ class FoundationDbMessageStoreTest {
.orElseThrow())
.toList();
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final CountDownLatch latch = new CountDownLatch(1);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID,
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device,
FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN,
FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES,
latch::countDown);
@ -792,7 +832,12 @@ class FoundationDbMessageStoreTest {
void acknowledgeMessagesEpochChange() throws InterruptedException {
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = Device.PRIMARY_ID;
final Device device = new Device();
device.setId(deviceId);
final int messagesPerBatch = 8;
final AtomicLong serialTimestamp = new AtomicLong();
generateAndInsertMessages(aci, deviceId, DEFAULT_EPOCH, messagesPerBatch, serialTimestamp::getAndIncrement);
@ -802,7 +847,7 @@ class FoundationDbMessageStoreTest {
final CountDownLatch cleanupLatch = new CountDownLatch(1);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci,
deviceId,
device,
FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN,
FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES,
cleanupLatch::countDown);
@ -825,7 +870,7 @@ class FoundationDbMessageStoreTest {
}
{
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, deviceId);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device);
final List<MessageStreamEntry> retrievedEntries = new ArrayList<>();
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()))
@ -857,7 +902,9 @@ class FoundationDbMessageStoreTest {
foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join();
}
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID,
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device,
FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN,
maxUnacknowledgedMessages,
Util.NOOP);
@ -902,7 +949,10 @@ class FoundationDbMessageStoreTest {
}
});
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID,
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device,
FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN,
FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES,
cleanUpLatch::countDown);