diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStream.java new file mode 100644 index 000000000..3cc75781f --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStream.java @@ -0,0 +1,155 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import java.util.HashSet; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Flow; +import java.util.function.BiConsumer; +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream; +import org.whispersystems.textsecuregcm.util.UUIDUtil; +import reactor.adapter.JdkFlowAdapter; +import reactor.core.publisher.BaseSubscriber; + +/// A temporary message stream that can mirror message acknowledgements (deletion requests) to FoundationDB +public class AcknowledgementMirroringMessageStream implements MessageStream { + + private final RedisDynamoDbMessageStream redisDynamoDbMessageStream; + private final FoundationDbMessageStream foundationDbMessageStream; + + private final Set messagesPendingAcknowledgement = new HashSet<>(); + + private static final int HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_THRESHOLD = 1024; + + @VisibleForTesting + static final int MAX_PENDING_ACKNOWLEDGEMENTS = 8192; + + @VisibleForTesting + static final int FOUNDATIONDB_REQUEST_SIZE = 100; + + private static final Counter HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_COUNTER = + Metrics.counter(MetricsUtil.name(AcknowledgementMirroringMessageStream.class, "highPendingAcknowledgementCount")); + + private static final Counter MAX_PENDING_ACKNOWLEDGEMENT_LIMIT_BREACHED_COUNTER = + Metrics.counter(MetricsUtil.name(AcknowledgementMirroringMessageStream.class, "maxPendingAcknowledgementLimitBreached")); + + private static class FoundationDbSubscriber extends BaseSubscriber { + + private final BiConsumer mirroredMessageHandler; + + private long mirroredEntriesDelivered = 0; + private long requested = 0; + + private FoundationDbSubscriber(final BiConsumer mirroredMessageHandler) { + this.mirroredMessageHandler = mirroredMessageHandler; + } + + public void handleRedisDynamoDbMessageStreamEntry(final MessageStreamEntry messageStreamEntry) { + final boolean isMirroredMessage = switch (messageStreamEntry) { + case MessageStreamEntry.Envelope envelopeEntry -> + UUIDUtil.fromByteString(envelopeEntry.message().getServerGuid()).version() == 8; + + case MessageStreamEntry.QueueEmpty _ -> false; + }; + + if (!isMirroredMessage) { + return; + } + + final boolean requestMore; + + synchronized (this) { + requestMore = ++mirroredEntriesDelivered > requested; + + if (requestMore) { + requested += FOUNDATIONDB_REQUEST_SIZE; + } + } + + if (requestMore) { + request(FOUNDATIONDB_REQUEST_SIZE); + } + } + + @Override + protected void hookOnNext(final MessageStreamEntry.Envelope envelope) { + mirroredMessageHandler.accept( + UUIDUtil.fromByteString(envelope.message().getServerGuid()), envelope.message().getServerTimestamp()); + } + } + + public AcknowledgementMirroringMessageStream(final RedisDynamoDbMessageStream redisDynamoDbMessageStream, + final FoundationDbMessageStream foundationDbMessageStream) { + + this.redisDynamoDbMessageStream = redisDynamoDbMessageStream; + this.foundationDbMessageStream = foundationDbMessageStream; + } + + @Override + public Flow.Publisher getMessages() { + final FoundationDbSubscriber subscriber = new FoundationDbSubscriber(this::handleMirroredMessageAcknowledged); + + JdkFlowAdapter.flowPublisherToFlux(foundationDbMessageStream.getMessages()) + .filter(messageStreamEntry -> messageStreamEntry instanceof MessageStreamEntry.Envelope) + .cast(MessageStreamEntry.Envelope.class) + .subscribe(subscriber); + + return JdkFlowAdapter.publisherToFlowPublisher( + JdkFlowAdapter.flowPublisherToFlux(redisDynamoDbMessageStream.getMessages()) + // Mirror demand and termination signals to the FoundationDB message stream + .doOnNext(subscriber::handleRedisDynamoDbMessageStreamEntry) + .doOnComplete(subscriber::dispose)); + } + + @Override + public CompletableFuture acknowledgeMessage(final UUID messageGuid, final long serverTimestamp) { + // All messages stored in FoundationDB use version 8 UUIDs; if a message has a version 4 UUID, then it only exists + // in Redis/DynamoDB + if (messageGuid.version() == 8) { + handleMirroredMessageAcknowledged(messageGuid, serverTimestamp); + } + + return redisDynamoDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp); + } + + private void handleMirroredMessageAcknowledged(final UUID messageGuid, final long serverTimestamp) { + // Before we can acknowledge and delete the message in FoundationDB, two things need to happen: + // + // 1. The client needs to acknowledge the message on the Redis/DDB stream. + // 2. The message needs to pass through the FoundationDB message stream pipeline because it needs to be registered + // in the acknowledged message buffer machinery before it can be deleted. + // + // We can't guarantee the order in which these two things happen, so we use a synchronized set to coordinate and + // determine when both have happened. + synchronized (messagesPendingAcknowledgement) { + if (messagesPendingAcknowledgement.remove(messageGuid)) { + // One of the two streams had already acknowledged this message, and this is the second acknowledgement. Now we + // can pass the acknowledgement along to FoundationDB. + foundationDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp); + } else { + // Either this message came from FoundationDB and got auto-acknowledged before getting explicitly acknowledged + // on the Redis/DynamoDB stream or the message got explicitly acknowledged on the Redis/DynamoDB stream before + // passing through the FoundationDB stream. Either way, wait for both streams to have done their part. + if (messagesPendingAcknowledgement.add(messageGuid)) { + if (messagesPendingAcknowledgement.size() == HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_THRESHOLD) { + HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_COUNTER.increment(); + } + + if (messagesPendingAcknowledgement.size() > MAX_PENDING_ACKNOWLEDGEMENTS) { + MAX_PENDING_ACKNOWLEDGEMENT_LIMIT_BREACHED_COUNTER.increment(); + throw new IllegalStateException("Too many pending acknowledgements"); + } + } + } + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStream.java deleted file mode 100644 index 08e3afbef..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStream.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2026 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; -import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; -import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore; - -/// A temporary message stream that can mirror message acknowledgements (deletion requests) to FoundationDB -public class DeletionMirroringRedisDynamoDbMessageStream implements MessageStream { - - private final RedisDynamoDbMessageStream redisDynamoDbMessageStream; - private final FoundationDbMessageStore foundationDbMessageStore; - private final ExperimentEnrollmentManager experimentEnrollmentManager; - - private final AciServiceIdentifier accountIdentifier; - private final byte deviceId; - - private static final Logger logger = LoggerFactory.getLogger(DeletionMirroringRedisDynamoDbMessageStream.class); - - public DeletionMirroringRedisDynamoDbMessageStream(final RedisDynamoDbMessageStream redisDynamoDbMessageStream, - final FoundationDbMessageStore foundationDbMessageStore, - final ExperimentEnrollmentManager experimentEnrollmentManager, - final UUID accountIdentifier, - final byte deviceId) { - - this.redisDynamoDbMessageStream = redisDynamoDbMessageStream; - this.foundationDbMessageStore = foundationDbMessageStore; - this.experimentEnrollmentManager = experimentEnrollmentManager; - - this.accountIdentifier = new AciServiceIdentifier(accountIdentifier); - this.deviceId = deviceId; - } - - @Override - public Flow.Publisher getMessages() { - return redisDynamoDbMessageStream.getMessages(); - } - - @Override - public CompletableFuture acknowledgeMessage(final UUID messageGuid, final long serverTimestamp) { - // All messages stored in FoundationDB use version 8 UUIDs; if a message has a version 4 UUID, then it only exists - // in Redis/DynamoDB - if (messageGuid.version() == 8 && - experimentEnrollmentManager.isEnrolled(accountIdentifier.uuid(), MessagesManager.MIRROR_DELETIONS_EXPERIMENT_NAME)) { - - foundationDbMessageStore.delete(accountIdentifier, deviceId, messageGuid) - .whenComplete((_, throwable) -> { - if (throwable != null) { - logger.warn("Failed to delete message {}/{}/{} from FoundationDb", accountIdentifier.uuid(), deviceId, - messageGuid, throwable); - } - }); - } - - return redisDynamoDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 9a49cc0a5..939e54d06 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -38,6 +38,7 @@ import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager; import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore; +import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream; import org.whispersystems.textsecuregcm.util.UUIDUtil; import reactor.core.observability.micrometer.Micrometer; import reactor.core.publisher.Flux; @@ -279,12 +280,15 @@ public class MessagesManager { } public MessageStream getMessages(final UUID destinationUuid, final Device destinationDevice) { - return new DeletionMirroringRedisDynamoDbMessageStream( - new RedisDynamoDbMessageStream(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, destinationUuid, destinationDevice), - foundationDbMessageStore, - experimentEnrollmentManager, - destinationUuid, - destinationDevice.getId()); + final RedisDynamoDbMessageStream redisDynamoDbMessageStream = + new RedisDynamoDbMessageStream(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, + destinationUuid, destinationDevice); + + return experimentEnrollmentManager.isEnrolled(destinationUuid, MIRROR_DELETIONS_EXPERIMENT_NAME) + ? new AcknowledgementMirroringMessageStream( + redisDynamoDbMessageStream, + foundationDbMessageStore.getMessages(new AciServiceIdentifier(destinationUuid), destinationDevice.getId())) + : redisDynamoDbMessageStream; } Publisher getMessagesForDevice(final UUID destinationUuid, final Device destinationDevice) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java index d59cf7949..63e5aa7c5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -9,6 +9,9 @@ import com.apple.foundationdb.tuple.Versionstamp; import com.google.common.annotations.VisibleForTesting; import com.google.common.hash.Hashing; import io.dropwizard.util.DataSize; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; import java.time.Clock; import java.time.Duration; import java.util.ArrayList; @@ -22,14 +25,9 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Timer; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.MessageStream; import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Util; @@ -71,12 +69,6 @@ public class FoundationDbMessageStore { private static final Timer INSERT_MESSAGE_BATCH_TIMER = Metrics.timer(MetricsUtil.name(FoundationDbMessageStore.class, "insertMessageBatchTimer")); - private static final Counter DELETE_MESSAGE_COUNTER = - Metrics.counter(MetricsUtil.name(FoundationDbMessageStore.class, "deleteMessage")); - - private static final Timer DELETE_MESSAGE_TIMER = - Metrics.timer(MetricsUtil.name(FoundationDbMessageStore.class, "deleteMessageTimer")); - /// Result of inserting a message for a particular device /// /// @param versionstamp the versionstamp of the transaction in which this device's message was inserted, empty @@ -324,27 +316,6 @@ public class FoundationDbMessageStore { }); } - // Note that this method is intended only for initial migration support; in general, callers should clear messages - // by acknowledging messages via a `FoundationDbMessageStream`. - public CompletableFuture delete(final AciServiceIdentifier aci, final byte deviceId, final UUID messageGuid) { - return delete(aci, deviceId, versionstampUUIDCipher.decryptVersionstamp(messageGuid, aci.uuid(), deviceId)); - } - - private CompletableFuture delete(final AciServiceIdentifier aci, final byte deviceId, final Versionstamp versionstamp) { - final Timer.Sample sample = Timer.start(); - - final byte[] messageKey = getDeviceQueueSubspace(aci, deviceId).pack(Tuple.from(versionstamp)); - - return databasesByEpoch[getConfigurationEpoch(versionstamp)][getShardId(versionstamp)].runAsync(transaction -> { - transaction.clear(messageKey); - return CompletableFuture.completedFuture(null); - }) - .thenRun(() -> { - sample.stop(DELETE_MESSAGE_TIMER); - DELETE_MESSAGE_COUNTER.increment(); - }); - } - public void clearAll(final AciServiceIdentifier aci) { doForAllDatabasesWithMessages(aci, database -> database.run(transaction -> { transaction.clear(getAccountSubspace(aci).range()); @@ -367,14 +338,14 @@ public class FoundationDbMessageStore { .forEach(action); } - public MessageStream getMessages(final AciServiceIdentifier aci, final Device destinationDevice) { - return getMessages(aci, destinationDevice, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, + public FoundationDbMessageStream getMessages(final AciServiceIdentifier aci, final byte deviceId) { + return getMessages(aci, deviceId, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, Util.NOOP); } @VisibleForTesting - MessageStream getMessages(final AciServiceIdentifier aci, - final Device destinationDevice, + FoundationDbMessageStream getMessages(final AciServiceIdentifier aci, + final byte deviceId, final int maxMessagesPerScan, final int maxUnacknowledgedMessages, final Runnable doAfterCleanup) { @@ -388,10 +359,10 @@ public class FoundationDbMessageStore { : null; } - return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, destinationDevice.getId()), + return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, deviceId), getMessagesAvailableWatchKey(aci), databasesForQueueByEpoch, - new MessageGuidCodec(aci.uuid(), destinationDevice.getId(), versionstampUUIDCipher), + new MessageGuidCodec(aci.uuid(), deviceId, versionstampUUIDCipher), maxMessagesPerScan, maxUnacknowledgedMessages, doAfterCleanup); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java index 2a713a001..c87d6b2b3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java @@ -51,8 +51,15 @@ public class FoundationDbMessageStream implements MessageStream { private final Map acknowledgedMessageBuffersByDatabase; - private final Counter staleEphemeralMessagesCounter = Metrics.counter( - name(FoundationDbMessageStream.class, "staleEphemeralMessages")); + private final Counter messageReadCounter = + Metrics.counter(name(FoundationDbMessageStream.class, "messagesRead")); + + private final Counter messageAcknowledgedCounter = + Metrics.counter(name(FoundationDbMessageStream.class, "messagesAcknowledged")); + + private final Counter staleEphemeralMessagesCounter = + Metrics.counter(name(FoundationDbMessageStream.class, "staleEphemeralMessages")); + static final int DEFAULT_MAX_MESSAGES_PER_SCAN = 1024; @VisibleForTesting static final int DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES = 16_384; @@ -116,6 +123,11 @@ public class FoundationDbMessageStream implements MessageStream { sink.next(messageStreamEntry); }) .map(fdbMessageStreamEntry -> fdbMessageStreamEntry.toMessageStreamEntry(messageGuidCodec)) + .doOnNext(messageStreamEntry -> { + if (messageStreamEntry instanceof MessageStreamEntry.Envelope) { + messageReadCounter.increment(); + } + }) .doFinally(_ -> flushAllAcknowledgedMessages().thenRun(doAfterCleanup)); } @@ -225,6 +237,8 @@ public class FoundationDbMessageStream implements MessageStream { final Versionstamp versionstamp = messageGuidCodec.decodeMessageGuid(messageGuid); getAcknowledgedMessageBuffer(versionstamp).acknowledgeMessage(versionstamp); + messageAcknowledgedCounter.increment(); + return CompletableFuture.completedFuture(null); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStreamTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStreamTest.java new file mode 100644 index 000000000..4f3c35fd1 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStreamTest.java @@ -0,0 +1,146 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream; +import org.whispersystems.textsecuregcm.util.UUIDUtil; +import reactor.adapter.JdkFlowAdapter; +import reactor.core.publisher.Flux; + +class AcknowledgementMirroringMessageStreamTest { + + private RedisDynamoDbMessageStream redisDynamoDbMessageStream; + private FoundationDbMessageStream foundationDbMessageStream; + + private AcknowledgementMirroringMessageStream mirroringMessageStream; + + private static long serialMessageTimestamp = 0; + + private enum UUIDVersion { + V4, + V8 + } + + @BeforeEach + void setUp() { + redisDynamoDbMessageStream = mock(RedisDynamoDbMessageStream.class); + + when(redisDynamoDbMessageStream.acknowledgeMessage(any(), anyLong())) + .thenReturn(CompletableFuture.completedFuture(null)); + + foundationDbMessageStream = mock(FoundationDbMessageStream.class); + + when(foundationDbMessageStream.acknowledgeMessage(any(), anyLong())) + .thenReturn(CompletableFuture.completedFuture(null)); + + mirroringMessageStream = new AcknowledgementMirroringMessageStream( + redisDynamoDbMessageStream, + foundationDbMessageStream); + } + + @ParameterizedTest + @EnumSource(UUIDVersion.class) + void acknowledgeMessages(final UUIDVersion uuidVersion) { + final List entries = new ArrayList<>(); + { + for (int i = 0; i < AcknowledgementMirroringMessageStream.FOUNDATIONDB_REQUEST_SIZE + 1; i++) { + entries.add(generateEnvelopeEntry(uuidVersion)); + } + + entries.add(new MessageStreamEntry.QueueEmpty()); + } + + when(redisDynamoDbMessageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(entries))); + + when(foundationDbMessageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable( + uuidVersion == UUIDVersion.V8 ? entries : Collections.emptyList()))); + + JdkFlowAdapter.flowPublisherToFlux(mirroringMessageStream.getMessages()) + .doOnNext(entry -> { + if (entry instanceof MessageStreamEntry.Envelope(MessageProtos.Envelope message)) { + mirroringMessageStream.acknowledgeMessage(UUIDUtil.fromByteString(message.getServerGuid()), + message.getServerTimestamp()); + } + }) + .takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty) + .then() + .block(); + + if (uuidVersion == UUIDVersion.V8) { + entries.stream() + .filter(streamEntry -> streamEntry instanceof MessageStreamEntry.Envelope) + .map(streamEntry -> ((MessageStreamEntry.Envelope) streamEntry).message()) + .forEach(envelope -> verify(foundationDbMessageStream).acknowledgeMessage( + UUIDUtil.fromByteString(envelope.getServerGuid()), envelope.getServerTimestamp())); + } else { + verify(foundationDbMessageStream, never()).acknowledgeMessage(any(), anyLong()); + } + } + + @Test + void acknowledgeMessagesOverflow() { + final List entries = new ArrayList<>(); + { + for (int i = 0; i < AcknowledgementMirroringMessageStream.MAX_PENDING_ACKNOWLEDGEMENTS + 1; i++) { + entries.add(generateEnvelopeEntry(UUIDVersion.V8)); + } + + entries.add(new MessageStreamEntry.QueueEmpty()); + } + + when(redisDynamoDbMessageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(entries))); + + when(foundationDbMessageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.empty())); + + assertThrows(IllegalStateException.class, + () -> JdkFlowAdapter.flowPublisherToFlux(mirroringMessageStream.getMessages()) + .doOnNext(entry -> { + if (entry instanceof MessageStreamEntry.Envelope(MessageProtos.Envelope message)) { + mirroringMessageStream.acknowledgeMessage(UUIDUtil.fromByteString(message.getServerGuid()), + message.getServerTimestamp()); + } + }) + .takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty) + .then() + .block()); + + verify(foundationDbMessageStream, never()).acknowledgeMessage(any(), anyLong()); + } + + private static MessageStreamEntry.Envelope generateEnvelopeEntry(final UUIDVersion uuidVersion) { + final UUID messageGuid = switch (uuidVersion) { + case V4 -> UUID.randomUUID(); + case V8 -> MessageGuidUtil.generateRandomV8UUID(); + }; + + return new MessageStreamEntry.Envelope(MessageProtos.Envelope.newBuilder() + .setServerGuid(UUIDUtil.toByteString(messageGuid)) + .setServerTimestamp(serialMessageTimestamp++) + .build()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStreamTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStreamTest.java deleted file mode 100644 index 08414df6b..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStreamTest.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2026 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.List; -import java.util.UUID; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; -import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; -import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore; - -class DeletionMirroringRedisDynamoDbMessageStreamTest { - - private FoundationDbMessageStore foundationDbMessageStore; - private ExperimentEnrollmentManager experimentEnrollmentManager; - - private DeletionMirroringRedisDynamoDbMessageStream deletionMirroringRedisDynamoDbMessageStream; - - private static final AciServiceIdentifier ACCOUNT_IDENTIFIER = new AciServiceIdentifier(UUID.randomUUID()); - private static final byte DEVICE_ID = Device.PRIMARY_ID; - - @BeforeEach - void setUp() { - foundationDbMessageStore = mock(FoundationDbMessageStore.class); - experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); - - deletionMirroringRedisDynamoDbMessageStream = new DeletionMirroringRedisDynamoDbMessageStream( - mock(RedisDynamoDbMessageStream.class), - foundationDbMessageStore, - experimentEnrollmentManager, - ACCOUNT_IDENTIFIER.uuid(), - DEVICE_ID); - } - - @ParameterizedTest - @MethodSource - void acknowledgeMessage(final boolean enrolled, final UUID messageGuid, final boolean expectFoundationDbDeletion) { - when(experimentEnrollmentManager.isEnrolled(any(UUID.class), eq(MessagesManager.MIRROR_DELETIONS_EXPERIMENT_NAME))) - .thenReturn(enrolled); - - deletionMirroringRedisDynamoDbMessageStream.acknowledgeMessage(messageGuid, System.currentTimeMillis()); - - verify(foundationDbMessageStore, times(expectFoundationDbDeletion ? 1 : 0)) - .delete(ACCOUNT_IDENTIFIER, DEVICE_ID, messageGuid); - } - - private static List acknowledgeMessage() { - return List.of( - Arguments.argumentSet("Not enrolled, v4 UUID", false, UUID.randomUUID(), false), - Arguments.argumentSet("Not enrolled, v8 UUID", false, MessageGuidUtil.generateRandomV8UUID(), false), - Arguments.argumentSet("Enrolled, v4 UUID", true, UUID.randomUUID(), false), - Arguments.argumentSet("Enrolled, v8 UUID", true, MessageGuidUtil.generateRandomV8UUID(), true) - ); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java index 1d7303ffe..212a5455e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -452,36 +452,6 @@ class FoundationDbMessageStoreTest { Map.of(generateRandomAciForShard(0), Collections.emptyMap()))); } - @Test - void delete() { - final AciServiceIdentifier deletedMessageAci = new AciServiceIdentifier(UUID.randomUUID()); - final byte deletedMessageDeviceId = Device.PRIMARY_ID; - final MessageProtos.Envelope deletedMessage = generateRandomMessage(false); - - final AciServiceIdentifier retainedMessageAci = new AciServiceIdentifier(UUID.randomUUID()); - final byte retainedMessageDeviceId = Device.PRIMARY_ID; - final MessageProtos.Envelope retainedMessage = generateRandomMessage(false); - - final UUID deletedMessageGuid = - foundationDbMessageStore.insert(deletedMessageAci, Map.of(deletedMessageDeviceId, deletedMessage)).join() - .get(deletedMessageDeviceId).messageGuid().orElseThrow(); - - foundationDbMessageStore.insert(retainedMessageAci, Map.of(retainedMessageDeviceId, retainedMessage)).join(); - - assertArrayEquals(deletedMessage.toByteArray(), - getItemsInDeviceQueue(deletedMessageAci, deletedMessageDeviceId).getFirst().getValue()); - - assertArrayEquals(retainedMessage.toByteArray(), - getItemsInDeviceQueue(retainedMessageAci, retainedMessageDeviceId).getFirst().getValue()); - - foundationDbMessageStore.delete(deletedMessageAci, deletedMessageDeviceId, deletedMessageGuid).join(); - - assertTrue(getItemsInDeviceQueue(deletedMessageAci, deletedMessageDeviceId).isEmpty()); - - assertArrayEquals(retainedMessage.toByteArray(), - getItemsInDeviceQueue(retainedMessageAci, retainedMessageDeviceId).getFirst().getValue()); - } - @Test void clearAllForAccount() { final AciServiceIdentifier deletedAccountIdentifier = new AciServiceIdentifier(UUID.randomUUID()); @@ -553,10 +523,9 @@ class FoundationDbMessageStoreTest { .orElseThrow()) .toList(); - final Device device = new Device(); - device.setId(Device.PRIMARY_ID); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, batchSize, numMessages, - Util.NOOP); + final MessageStream messageStream = + foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, batchSize, numMessages, Util.NOOP); + final List retrievedEntries = new ArrayList<>(); StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())) .recordWith(() -> retrievedEntries) @@ -592,9 +561,6 @@ class FoundationDbMessageStoreTest { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); final byte deviceId = Device.PRIMARY_ID; - final Device device = new Device(); - device.setId(deviceId); - final int messagesPerBatch = 8; final AtomicLong serialTimestamp = new AtomicLong(); @@ -608,7 +574,7 @@ class FoundationDbMessageStoreTest { final List liveDefaultEpochMessages = new ArrayList<>(); final List liveFutureEpochMessages = new ArrayList<>(); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, deviceId); final List retrievedEntries = new ArrayList<>(); final CountDownLatch queueEmptyLatch = new CountDownLatch(1); @@ -689,9 +655,7 @@ class FoundationDbMessageStoreTest { } }); - final Device device = new Device(); - device.setId(Device.PRIMARY_ID); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID); StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())) .expectNext(new MessageStreamEntry.Envelope(message1 .toBuilder() @@ -735,9 +699,7 @@ class FoundationDbMessageStoreTest { } }); - final Device device = new Device(); - device.setId(Device.PRIMARY_ID); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID); // Initially only request a single message, then give the go ahead to the publisher. This verifies that we get the // queue empty signal when we consume the initial batch of messages even though the publisher keeps publishing in // the meantime. @@ -768,10 +730,8 @@ class FoundationDbMessageStoreTest { .orElseThrow()) .toList(); - final Device device = new Device(); - device.setId(Device.PRIMARY_ID); final CountDownLatch latch = new CountDownLatch(1); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, latch::countDown); @@ -832,12 +792,7 @@ class FoundationDbMessageStoreTest { void acknowledgeMessagesEpochChange() throws InterruptedException { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); final byte deviceId = Device.PRIMARY_ID; - - final Device device = new Device(); - device.setId(deviceId); - final int messagesPerBatch = 8; - final AtomicLong serialTimestamp = new AtomicLong(); generateAndInsertMessages(aci, deviceId, DEFAULT_EPOCH, messagesPerBatch, serialTimestamp::getAndIncrement); @@ -847,7 +802,7 @@ class FoundationDbMessageStoreTest { final CountDownLatch cleanupLatch = new CountDownLatch(1); final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, - device, + deviceId, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, cleanupLatch::countDown); @@ -870,7 +825,7 @@ class FoundationDbMessageStoreTest { } { - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, deviceId); final List retrievedEntries = new ArrayList<>(); StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())) @@ -902,9 +857,7 @@ class FoundationDbMessageStoreTest { foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join(); } - final Device device = new Device(); - device.setId(Device.PRIMARY_ID); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, maxUnacknowledgedMessages, Util.NOOP); @@ -949,10 +902,7 @@ class FoundationDbMessageStoreTest { } }); - final Device device = new Device(); - device.setId(Device.PRIMARY_ID); - - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, cleanUpLatch::countDown);