From 28aefe0ebe627415ffee97181b707ca3a3c2a6d4 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Wed, 24 Jun 2026 12:58:20 -0400 Subject: [PATCH] Revert "Mirror message reads/acknowledgements via `FoundationDbMessageStream`" This reverts commit 4e52317e26a19a398378943460df4d3053e5966a. --- ...AcknowledgementMirroringMessageStream.java | 155 ------------------ ...onMirroringRedisDynamoDbMessageStream.java | 66 ++++++++ .../storage/MessagesManager.java | 16 +- .../FoundationDbMessageStore.java | 47 +++++- .../FoundationDbMessageStream.java | 18 +- ...owledgementMirroringMessageStreamTest.java | 146 ----------------- ...rroringRedisDynamoDbMessageStreamTest.java | 68 ++++++++ .../FoundationDbMessageStoreTest.java | 72 ++++++-- 8 files changed, 241 insertions(+), 347 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStream.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStream.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStreamTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStreamTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStream.java deleted file mode 100644 index 3cc75781f..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStream.java +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright 2026 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import java.util.HashSet; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import java.util.function.BiConsumer; -import com.google.common.annotations.VisibleForTesting; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream; -import org.whispersystems.textsecuregcm.util.UUIDUtil; -import reactor.adapter.JdkFlowAdapter; -import reactor.core.publisher.BaseSubscriber; - -/// A temporary message stream that can mirror message acknowledgements (deletion requests) to FoundationDB -public class AcknowledgementMirroringMessageStream implements MessageStream { - - private final RedisDynamoDbMessageStream redisDynamoDbMessageStream; - private final FoundationDbMessageStream foundationDbMessageStream; - - private final Set messagesPendingAcknowledgement = new HashSet<>(); - - private static final int HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_THRESHOLD = 1024; - - @VisibleForTesting - static final int MAX_PENDING_ACKNOWLEDGEMENTS = 8192; - - @VisibleForTesting - static final int FOUNDATIONDB_REQUEST_SIZE = 100; - - private static final Counter HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_COUNTER = - Metrics.counter(MetricsUtil.name(AcknowledgementMirroringMessageStream.class, "highPendingAcknowledgementCount")); - - private static final Counter MAX_PENDING_ACKNOWLEDGEMENT_LIMIT_BREACHED_COUNTER = - Metrics.counter(MetricsUtil.name(AcknowledgementMirroringMessageStream.class, "maxPendingAcknowledgementLimitBreached")); - - private static class FoundationDbSubscriber extends BaseSubscriber { - - private final BiConsumer mirroredMessageHandler; - - private long mirroredEntriesDelivered = 0; - private long requested = 0; - - private FoundationDbSubscriber(final BiConsumer mirroredMessageHandler) { - this.mirroredMessageHandler = mirroredMessageHandler; - } - - public void handleRedisDynamoDbMessageStreamEntry(final MessageStreamEntry messageStreamEntry) { - final boolean isMirroredMessage = switch (messageStreamEntry) { - case MessageStreamEntry.Envelope envelopeEntry -> - UUIDUtil.fromByteString(envelopeEntry.message().getServerGuid()).version() == 8; - - case MessageStreamEntry.QueueEmpty _ -> false; - }; - - if (!isMirroredMessage) { - return; - } - - final boolean requestMore; - - synchronized (this) { - requestMore = ++mirroredEntriesDelivered > requested; - - if (requestMore) { - requested += FOUNDATIONDB_REQUEST_SIZE; - } - } - - if (requestMore) { - request(FOUNDATIONDB_REQUEST_SIZE); - } - } - - @Override - protected void hookOnNext(final MessageStreamEntry.Envelope envelope) { - mirroredMessageHandler.accept( - UUIDUtil.fromByteString(envelope.message().getServerGuid()), envelope.message().getServerTimestamp()); - } - } - - public AcknowledgementMirroringMessageStream(final RedisDynamoDbMessageStream redisDynamoDbMessageStream, - final FoundationDbMessageStream foundationDbMessageStream) { - - this.redisDynamoDbMessageStream = redisDynamoDbMessageStream; - this.foundationDbMessageStream = foundationDbMessageStream; - } - - @Override - public Flow.Publisher getMessages() { - final FoundationDbSubscriber subscriber = new FoundationDbSubscriber(this::handleMirroredMessageAcknowledged); - - JdkFlowAdapter.flowPublisherToFlux(foundationDbMessageStream.getMessages()) - .filter(messageStreamEntry -> messageStreamEntry instanceof MessageStreamEntry.Envelope) - .cast(MessageStreamEntry.Envelope.class) - .subscribe(subscriber); - - return JdkFlowAdapter.publisherToFlowPublisher( - JdkFlowAdapter.flowPublisherToFlux(redisDynamoDbMessageStream.getMessages()) - // Mirror demand and termination signals to the FoundationDB message stream - .doOnNext(subscriber::handleRedisDynamoDbMessageStreamEntry) - .doOnComplete(subscriber::dispose)); - } - - @Override - public CompletableFuture acknowledgeMessage(final UUID messageGuid, final long serverTimestamp) { - // All messages stored in FoundationDB use version 8 UUIDs; if a message has a version 4 UUID, then it only exists - // in Redis/DynamoDB - if (messageGuid.version() == 8) { - handleMirroredMessageAcknowledged(messageGuid, serverTimestamp); - } - - return redisDynamoDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp); - } - - private void handleMirroredMessageAcknowledged(final UUID messageGuid, final long serverTimestamp) { - // Before we can acknowledge and delete the message in FoundationDB, two things need to happen: - // - // 1. The client needs to acknowledge the message on the Redis/DDB stream. - // 2. The message needs to pass through the FoundationDB message stream pipeline because it needs to be registered - // in the acknowledged message buffer machinery before it can be deleted. - // - // We can't guarantee the order in which these two things happen, so we use a synchronized set to coordinate and - // determine when both have happened. - synchronized (messagesPendingAcknowledgement) { - if (messagesPendingAcknowledgement.remove(messageGuid)) { - // One of the two streams had already acknowledged this message, and this is the second acknowledgement. Now we - // can pass the acknowledgement along to FoundationDB. - foundationDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp); - } else { - // Either this message came from FoundationDB and got auto-acknowledged before getting explicitly acknowledged - // on the Redis/DynamoDB stream or the message got explicitly acknowledged on the Redis/DynamoDB stream before - // passing through the FoundationDB stream. Either way, wait for both streams to have done their part. - if (messagesPendingAcknowledgement.add(messageGuid)) { - if (messagesPendingAcknowledgement.size() == HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_THRESHOLD) { - HIGH_PENDING_ACKNOWLEDGEMENT_COUNT_WARNING_COUNTER.increment(); - } - - if (messagesPendingAcknowledgement.size() > MAX_PENDING_ACKNOWLEDGEMENTS) { - MAX_PENDING_ACKNOWLEDGEMENT_LIMIT_BREACHED_COUNTER.increment(); - throw new IllegalStateException("Too many pending acknowledgements"); - } - } - } - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStream.java new file mode 100644 index 000000000..08e3afbef --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStream.java @@ -0,0 +1,66 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Flow; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore; + +/// A temporary message stream that can mirror message acknowledgements (deletion requests) to FoundationDB +public class DeletionMirroringRedisDynamoDbMessageStream implements MessageStream { + + private final RedisDynamoDbMessageStream redisDynamoDbMessageStream; + private final FoundationDbMessageStore foundationDbMessageStore; + private final ExperimentEnrollmentManager experimentEnrollmentManager; + + private final AciServiceIdentifier accountIdentifier; + private final byte deviceId; + + private static final Logger logger = LoggerFactory.getLogger(DeletionMirroringRedisDynamoDbMessageStream.class); + + public DeletionMirroringRedisDynamoDbMessageStream(final RedisDynamoDbMessageStream redisDynamoDbMessageStream, + final FoundationDbMessageStore foundationDbMessageStore, + final ExperimentEnrollmentManager experimentEnrollmentManager, + final UUID accountIdentifier, + final byte deviceId) { + + this.redisDynamoDbMessageStream = redisDynamoDbMessageStream; + this.foundationDbMessageStore = foundationDbMessageStore; + this.experimentEnrollmentManager = experimentEnrollmentManager; + + this.accountIdentifier = new AciServiceIdentifier(accountIdentifier); + this.deviceId = deviceId; + } + + @Override + public Flow.Publisher getMessages() { + return redisDynamoDbMessageStream.getMessages(); + } + + @Override + public CompletableFuture acknowledgeMessage(final UUID messageGuid, final long serverTimestamp) { + // All messages stored in FoundationDB use version 8 UUIDs; if a message has a version 4 UUID, then it only exists + // in Redis/DynamoDB + if (messageGuid.version() == 8 && + experimentEnrollmentManager.isEnrolled(accountIdentifier.uuid(), MessagesManager.MIRROR_DELETIONS_EXPERIMENT_NAME)) { + + foundationDbMessageStore.delete(accountIdentifier, deviceId, messageGuid) + .whenComplete((_, throwable) -> { + if (throwable != null) { + logger.warn("Failed to delete message {}/{}/{} from FoundationDb", accountIdentifier.uuid(), deviceId, + messageGuid, throwable); + } + }); + } + + return redisDynamoDbMessageStream.acknowledgeMessage(messageGuid, serverTimestamp); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 939e54d06..9a49cc0a5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -38,7 +38,6 @@ import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager; import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore; -import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream; import org.whispersystems.textsecuregcm.util.UUIDUtil; import reactor.core.observability.micrometer.Micrometer; import reactor.core.publisher.Flux; @@ -280,15 +279,12 @@ public class MessagesManager { } public MessageStream getMessages(final UUID destinationUuid, final Device destinationDevice) { - final RedisDynamoDbMessageStream redisDynamoDbMessageStream = - new RedisDynamoDbMessageStream(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, - destinationUuid, destinationDevice); - - return experimentEnrollmentManager.isEnrolled(destinationUuid, MIRROR_DELETIONS_EXPERIMENT_NAME) - ? new AcknowledgementMirroringMessageStream( - redisDynamoDbMessageStream, - foundationDbMessageStore.getMessages(new AciServiceIdentifier(destinationUuid), destinationDevice.getId())) - : redisDynamoDbMessageStream; + return new DeletionMirroringRedisDynamoDbMessageStream( + new RedisDynamoDbMessageStream(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, destinationUuid, destinationDevice), + foundationDbMessageStore, + experimentEnrollmentManager, + destinationUuid, + destinationDevice.getId()); } Publisher getMessagesForDevice(final UUID destinationUuid, final Device destinationDevice) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java index 63e5aa7c5..d59cf7949 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -9,9 +9,6 @@ import com.apple.foundationdb.tuple.Versionstamp; import com.google.common.annotations.VisibleForTesting; import com.google.common.hash.Hashing; import io.dropwizard.util.DataSize; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Timer; import java.time.Clock; import java.time.Duration; import java.util.ArrayList; @@ -25,9 +22,14 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.MessageStream; import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Util; @@ -69,6 +71,12 @@ public class FoundationDbMessageStore { private static final Timer INSERT_MESSAGE_BATCH_TIMER = Metrics.timer(MetricsUtil.name(FoundationDbMessageStore.class, "insertMessageBatchTimer")); + private static final Counter DELETE_MESSAGE_COUNTER = + Metrics.counter(MetricsUtil.name(FoundationDbMessageStore.class, "deleteMessage")); + + private static final Timer DELETE_MESSAGE_TIMER = + Metrics.timer(MetricsUtil.name(FoundationDbMessageStore.class, "deleteMessageTimer")); + /// Result of inserting a message for a particular device /// /// @param versionstamp the versionstamp of the transaction in which this device's message was inserted, empty @@ -316,6 +324,27 @@ public class FoundationDbMessageStore { }); } + // Note that this method is intended only for initial migration support; in general, callers should clear messages + // by acknowledging messages via a `FoundationDbMessageStream`. + public CompletableFuture delete(final AciServiceIdentifier aci, final byte deviceId, final UUID messageGuid) { + return delete(aci, deviceId, versionstampUUIDCipher.decryptVersionstamp(messageGuid, aci.uuid(), deviceId)); + } + + private CompletableFuture delete(final AciServiceIdentifier aci, final byte deviceId, final Versionstamp versionstamp) { + final Timer.Sample sample = Timer.start(); + + final byte[] messageKey = getDeviceQueueSubspace(aci, deviceId).pack(Tuple.from(versionstamp)); + + return databasesByEpoch[getConfigurationEpoch(versionstamp)][getShardId(versionstamp)].runAsync(transaction -> { + transaction.clear(messageKey); + return CompletableFuture.completedFuture(null); + }) + .thenRun(() -> { + sample.stop(DELETE_MESSAGE_TIMER); + DELETE_MESSAGE_COUNTER.increment(); + }); + } + public void clearAll(final AciServiceIdentifier aci) { doForAllDatabasesWithMessages(aci, database -> database.run(transaction -> { transaction.clear(getAccountSubspace(aci).range()); @@ -338,14 +367,14 @@ public class FoundationDbMessageStore { .forEach(action); } - public FoundationDbMessageStream getMessages(final AciServiceIdentifier aci, final byte deviceId) { - return getMessages(aci, deviceId, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, + public MessageStream getMessages(final AciServiceIdentifier aci, final Device destinationDevice) { + return getMessages(aci, destinationDevice, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, Util.NOOP); } @VisibleForTesting - FoundationDbMessageStream getMessages(final AciServiceIdentifier aci, - final byte deviceId, + MessageStream getMessages(final AciServiceIdentifier aci, + final Device destinationDevice, final int maxMessagesPerScan, final int maxUnacknowledgedMessages, final Runnable doAfterCleanup) { @@ -359,10 +388,10 @@ public class FoundationDbMessageStore { : null; } - return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, deviceId), + return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, destinationDevice.getId()), getMessagesAvailableWatchKey(aci), databasesForQueueByEpoch, - new MessageGuidCodec(aci.uuid(), deviceId, versionstampUUIDCipher), + new MessageGuidCodec(aci.uuid(), destinationDevice.getId(), versionstampUUIDCipher), maxMessagesPerScan, maxUnacknowledgedMessages, doAfterCleanup); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java index c87d6b2b3..2a713a001 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java @@ -51,15 +51,8 @@ public class FoundationDbMessageStream implements MessageStream { private final Map acknowledgedMessageBuffersByDatabase; - private final Counter messageReadCounter = - Metrics.counter(name(FoundationDbMessageStream.class, "messagesRead")); - - private final Counter messageAcknowledgedCounter = - Metrics.counter(name(FoundationDbMessageStream.class, "messagesAcknowledged")); - - private final Counter staleEphemeralMessagesCounter = - Metrics.counter(name(FoundationDbMessageStream.class, "staleEphemeralMessages")); - + private final Counter staleEphemeralMessagesCounter = Metrics.counter( + name(FoundationDbMessageStream.class, "staleEphemeralMessages")); static final int DEFAULT_MAX_MESSAGES_PER_SCAN = 1024; @VisibleForTesting static final int DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES = 16_384; @@ -123,11 +116,6 @@ public class FoundationDbMessageStream implements MessageStream { sink.next(messageStreamEntry); }) .map(fdbMessageStreamEntry -> fdbMessageStreamEntry.toMessageStreamEntry(messageGuidCodec)) - .doOnNext(messageStreamEntry -> { - if (messageStreamEntry instanceof MessageStreamEntry.Envelope) { - messageReadCounter.increment(); - } - }) .doFinally(_ -> flushAllAcknowledgedMessages().thenRun(doAfterCleanup)); } @@ -237,8 +225,6 @@ public class FoundationDbMessageStream implements MessageStream { final Versionstamp versionstamp = messageGuidCodec.decodeMessageGuid(messageGuid); getAcknowledgedMessageBuffer(versionstamp).acknowledgeMessage(versionstamp); - messageAcknowledgedCounter.increment(); - return CompletableFuture.completedFuture(null); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStreamTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStreamTest.java deleted file mode 100644 index 4f3c35fd1..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AcknowledgementMirroringMessageStreamTest.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright 2026 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.EnumSource; -import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStream; -import org.whispersystems.textsecuregcm.util.UUIDUtil; -import reactor.adapter.JdkFlowAdapter; -import reactor.core.publisher.Flux; - -class AcknowledgementMirroringMessageStreamTest { - - private RedisDynamoDbMessageStream redisDynamoDbMessageStream; - private FoundationDbMessageStream foundationDbMessageStream; - - private AcknowledgementMirroringMessageStream mirroringMessageStream; - - private static long serialMessageTimestamp = 0; - - private enum UUIDVersion { - V4, - V8 - } - - @BeforeEach - void setUp() { - redisDynamoDbMessageStream = mock(RedisDynamoDbMessageStream.class); - - when(redisDynamoDbMessageStream.acknowledgeMessage(any(), anyLong())) - .thenReturn(CompletableFuture.completedFuture(null)); - - foundationDbMessageStream = mock(FoundationDbMessageStream.class); - - when(foundationDbMessageStream.acknowledgeMessage(any(), anyLong())) - .thenReturn(CompletableFuture.completedFuture(null)); - - mirroringMessageStream = new AcknowledgementMirroringMessageStream( - redisDynamoDbMessageStream, - foundationDbMessageStream); - } - - @ParameterizedTest - @EnumSource(UUIDVersion.class) - void acknowledgeMessages(final UUIDVersion uuidVersion) { - final List entries = new ArrayList<>(); - { - for (int i = 0; i < AcknowledgementMirroringMessageStream.FOUNDATIONDB_REQUEST_SIZE + 1; i++) { - entries.add(generateEnvelopeEntry(uuidVersion)); - } - - entries.add(new MessageStreamEntry.QueueEmpty()); - } - - when(redisDynamoDbMessageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(entries))); - - when(foundationDbMessageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable( - uuidVersion == UUIDVersion.V8 ? entries : Collections.emptyList()))); - - JdkFlowAdapter.flowPublisherToFlux(mirroringMessageStream.getMessages()) - .doOnNext(entry -> { - if (entry instanceof MessageStreamEntry.Envelope(MessageProtos.Envelope message)) { - mirroringMessageStream.acknowledgeMessage(UUIDUtil.fromByteString(message.getServerGuid()), - message.getServerTimestamp()); - } - }) - .takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty) - .then() - .block(); - - if (uuidVersion == UUIDVersion.V8) { - entries.stream() - .filter(streamEntry -> streamEntry instanceof MessageStreamEntry.Envelope) - .map(streamEntry -> ((MessageStreamEntry.Envelope) streamEntry).message()) - .forEach(envelope -> verify(foundationDbMessageStream).acknowledgeMessage( - UUIDUtil.fromByteString(envelope.getServerGuid()), envelope.getServerTimestamp())); - } else { - verify(foundationDbMessageStream, never()).acknowledgeMessage(any(), anyLong()); - } - } - - @Test - void acknowledgeMessagesOverflow() { - final List entries = new ArrayList<>(); - { - for (int i = 0; i < AcknowledgementMirroringMessageStream.MAX_PENDING_ACKNOWLEDGEMENTS + 1; i++) { - entries.add(generateEnvelopeEntry(UUIDVersion.V8)); - } - - entries.add(new MessageStreamEntry.QueueEmpty()); - } - - when(redisDynamoDbMessageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(entries))); - - when(foundationDbMessageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.empty())); - - assertThrows(IllegalStateException.class, - () -> JdkFlowAdapter.flowPublisherToFlux(mirroringMessageStream.getMessages()) - .doOnNext(entry -> { - if (entry instanceof MessageStreamEntry.Envelope(MessageProtos.Envelope message)) { - mirroringMessageStream.acknowledgeMessage(UUIDUtil.fromByteString(message.getServerGuid()), - message.getServerTimestamp()); - } - }) - .takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty) - .then() - .block()); - - verify(foundationDbMessageStream, never()).acknowledgeMessage(any(), anyLong()); - } - - private static MessageStreamEntry.Envelope generateEnvelopeEntry(final UUIDVersion uuidVersion) { - final UUID messageGuid = switch (uuidVersion) { - case V4 -> UUID.randomUUID(); - case V8 -> MessageGuidUtil.generateRandomV8UUID(); - }; - - return new MessageStreamEntry.Envelope(MessageProtos.Envelope.newBuilder() - .setServerGuid(UUIDUtil.toByteString(messageGuid)) - .setServerTimestamp(serialMessageTimestamp++) - .build()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStreamTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStreamTest.java new file mode 100644 index 000000000..08414df6b --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeletionMirroringRedisDynamoDbMessageStreamTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.storage.foundationdb.FoundationDbMessageStore; + +class DeletionMirroringRedisDynamoDbMessageStreamTest { + + private FoundationDbMessageStore foundationDbMessageStore; + private ExperimentEnrollmentManager experimentEnrollmentManager; + + private DeletionMirroringRedisDynamoDbMessageStream deletionMirroringRedisDynamoDbMessageStream; + + private static final AciServiceIdentifier ACCOUNT_IDENTIFIER = new AciServiceIdentifier(UUID.randomUUID()); + private static final byte DEVICE_ID = Device.PRIMARY_ID; + + @BeforeEach + void setUp() { + foundationDbMessageStore = mock(FoundationDbMessageStore.class); + experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); + + deletionMirroringRedisDynamoDbMessageStream = new DeletionMirroringRedisDynamoDbMessageStream( + mock(RedisDynamoDbMessageStream.class), + foundationDbMessageStore, + experimentEnrollmentManager, + ACCOUNT_IDENTIFIER.uuid(), + DEVICE_ID); + } + + @ParameterizedTest + @MethodSource + void acknowledgeMessage(final boolean enrolled, final UUID messageGuid, final boolean expectFoundationDbDeletion) { + when(experimentEnrollmentManager.isEnrolled(any(UUID.class), eq(MessagesManager.MIRROR_DELETIONS_EXPERIMENT_NAME))) + .thenReturn(enrolled); + + deletionMirroringRedisDynamoDbMessageStream.acknowledgeMessage(messageGuid, System.currentTimeMillis()); + + verify(foundationDbMessageStore, times(expectFoundationDbDeletion ? 1 : 0)) + .delete(ACCOUNT_IDENTIFIER, DEVICE_ID, messageGuid); + } + + private static List acknowledgeMessage() { + return List.of( + Arguments.argumentSet("Not enrolled, v4 UUID", false, UUID.randomUUID(), false), + Arguments.argumentSet("Not enrolled, v8 UUID", false, MessageGuidUtil.generateRandomV8UUID(), false), + Arguments.argumentSet("Enrolled, v4 UUID", true, UUID.randomUUID(), false), + Arguments.argumentSet("Enrolled, v8 UUID", true, MessageGuidUtil.generateRandomV8UUID(), true) + ); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java index 212a5455e..1d7303ffe 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -452,6 +452,36 @@ class FoundationDbMessageStoreTest { Map.of(generateRandomAciForShard(0), Collections.emptyMap()))); } + @Test + void delete() { + final AciServiceIdentifier deletedMessageAci = new AciServiceIdentifier(UUID.randomUUID()); + final byte deletedMessageDeviceId = Device.PRIMARY_ID; + final MessageProtos.Envelope deletedMessage = generateRandomMessage(false); + + final AciServiceIdentifier retainedMessageAci = new AciServiceIdentifier(UUID.randomUUID()); + final byte retainedMessageDeviceId = Device.PRIMARY_ID; + final MessageProtos.Envelope retainedMessage = generateRandomMessage(false); + + final UUID deletedMessageGuid = + foundationDbMessageStore.insert(deletedMessageAci, Map.of(deletedMessageDeviceId, deletedMessage)).join() + .get(deletedMessageDeviceId).messageGuid().orElseThrow(); + + foundationDbMessageStore.insert(retainedMessageAci, Map.of(retainedMessageDeviceId, retainedMessage)).join(); + + assertArrayEquals(deletedMessage.toByteArray(), + getItemsInDeviceQueue(deletedMessageAci, deletedMessageDeviceId).getFirst().getValue()); + + assertArrayEquals(retainedMessage.toByteArray(), + getItemsInDeviceQueue(retainedMessageAci, retainedMessageDeviceId).getFirst().getValue()); + + foundationDbMessageStore.delete(deletedMessageAci, deletedMessageDeviceId, deletedMessageGuid).join(); + + assertTrue(getItemsInDeviceQueue(deletedMessageAci, deletedMessageDeviceId).isEmpty()); + + assertArrayEquals(retainedMessage.toByteArray(), + getItemsInDeviceQueue(retainedMessageAci, retainedMessageDeviceId).getFirst().getValue()); + } + @Test void clearAllForAccount() { final AciServiceIdentifier deletedAccountIdentifier = new AciServiceIdentifier(UUID.randomUUID()); @@ -523,9 +553,10 @@ class FoundationDbMessageStoreTest { .orElseThrow()) .toList(); - final MessageStream messageStream = - foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, batchSize, numMessages, Util.NOOP); - + final Device device = new Device(); + device.setId(Device.PRIMARY_ID); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, batchSize, numMessages, + Util.NOOP); final List retrievedEntries = new ArrayList<>(); StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())) .recordWith(() -> retrievedEntries) @@ -561,6 +592,9 @@ class FoundationDbMessageStoreTest { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); final byte deviceId = Device.PRIMARY_ID; + final Device device = new Device(); + device.setId(deviceId); + final int messagesPerBatch = 8; final AtomicLong serialTimestamp = new AtomicLong(); @@ -574,7 +608,7 @@ class FoundationDbMessageStoreTest { final List liveDefaultEpochMessages = new ArrayList<>(); final List liveFutureEpochMessages = new ArrayList<>(); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, deviceId); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); final List retrievedEntries = new ArrayList<>(); final CountDownLatch queueEmptyLatch = new CountDownLatch(1); @@ -655,7 +689,9 @@ class FoundationDbMessageStoreTest { } }); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID); + final Device device = new Device(); + device.setId(Device.PRIMARY_ID); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())) .expectNext(new MessageStreamEntry.Envelope(message1 .toBuilder() @@ -699,7 +735,9 @@ class FoundationDbMessageStoreTest { } }); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID); + final Device device = new Device(); + device.setId(Device.PRIMARY_ID); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); // Initially only request a single message, then give the go ahead to the publisher. This verifies that we get the // queue empty signal when we consume the initial batch of messages even though the publisher keeps publishing in // the meantime. @@ -730,8 +768,10 @@ class FoundationDbMessageStoreTest { .orElseThrow()) .toList(); + final Device device = new Device(); + device.setId(Device.PRIMARY_ID); final CountDownLatch latch = new CountDownLatch(1); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, latch::countDown); @@ -792,7 +832,12 @@ class FoundationDbMessageStoreTest { void acknowledgeMessagesEpochChange() throws InterruptedException { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); final byte deviceId = Device.PRIMARY_ID; + + final Device device = new Device(); + device.setId(deviceId); + final int messagesPerBatch = 8; + final AtomicLong serialTimestamp = new AtomicLong(); generateAndInsertMessages(aci, deviceId, DEFAULT_EPOCH, messagesPerBatch, serialTimestamp::getAndIncrement); @@ -802,7 +847,7 @@ class FoundationDbMessageStoreTest { final CountDownLatch cleanupLatch = new CountDownLatch(1); final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, - deviceId, + device, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, cleanupLatch::countDown); @@ -825,7 +870,7 @@ class FoundationDbMessageStoreTest { } { - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, deviceId); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); final List retrievedEntries = new ArrayList<>(); StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())) @@ -857,7 +902,9 @@ class FoundationDbMessageStoreTest { foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join(); } - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, + final Device device = new Device(); + device.setId(Device.PRIMARY_ID); + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, maxUnacknowledgedMessages, Util.NOOP); @@ -902,7 +949,10 @@ class FoundationDbMessageStoreTest { } }); - final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, Device.PRIMARY_ID, + final Device device = new Device(); + device.setId(Device.PRIMARY_ID); + + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN, FoundationDbMessageStream.DEFAULT_MAX_UNACKNOWLEDGED_MESSAGES, cleanUpLatch::countDown);