Implement FoundationDbMessageStream#acknowledgeMessage

This commit is contained in:
Ameya Lokare 2026-04-27 17:05:12 -07:00 committed by Jon Chambers
parent 68a5e4e8ee
commit defbc1c853
7 changed files with 414 additions and 101 deletions

View File

@ -4,29 +4,26 @@ import com.apple.foundationdb.Database;
import com.apple.foundationdb.KeySelector;
import com.apple.foundationdb.StreamingMode;
import com.apple.foundationdb.Transaction;
import com.google.protobuf.InvalidProtocolBufferException;
import com.apple.foundationdb.tuple.Versionstamp;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import javax.annotation.Nullable;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
/// Publishes a message stream from a device queue in FoundationDB. Capable of publishing both a finite stream for
/// catching up to end-of-queue,and an infinite stream for live updates.
class FoundationDbMessagePublisher {
private final Database database;
private final MessageGuidCodec messageGuidCodec;
/// The maximum number of messages we will fetch per range query operation to avoid excessive memory consumption
private final int maxMessagesPerScan;
/// The end key at which we stop reading messages. For finite publisher, this is just past the end-of-queue key at the
@ -76,14 +73,17 @@ class FoundationDbMessagePublisher {
@Nullable private final byte[] messagesAvailableWatchKey;
/// Listener to watch for state machine transitions; used for testing.
private final BiConsumer<State, State> stateChangeListener;
private final Flux<MessageStreamEntry.Envelope> messagePublisher;
private final Flux<FoundationDbMessageStreamEntry.Message> messagePublisher;
/// Whether the publisher is finite or infinite.
private final boolean terminateOnQueueEmpty;
/// A supplier that returns a function to execute before we begin fetching a page; it is passed the transaction that
/// is started for the page fetch operations.
private final Supplier<Consumer<Transaction>> beforePageFetch;
/// Tracks the current state of the publisher state machine. Initial state presumes that messages are available in the queue.
private State state = State.MESSAGES_AVAILABLE;
/// Reference to the sink we publishes messages to.
private volatile FluxSink<MessageStreamEntry.Envelope> emitter;
private volatile FluxSink<FoundationDbMessageStreamEntry.Message> emitter;
/// Future that completes when the watch for {@link #messagesAvailableWatchKey} triggers.
private CompletableFuture<Void> watchFuture;
@ -93,19 +93,19 @@ class FoundationDbMessagePublisher {
final KeySelector beginKeyInclusive,
final KeySelector endKeyExclusive,
final Database database,
final MessageGuidCodec messageGuidCodec,
final int maxMessagesPerScan,
@Nullable final byte[] messagesAvailableWatchKey,
@Nullable final BiConsumer<State, State> stateChangeListener) {
@Nullable final BiConsumer<State, State> stateChangeListener,
final Supplier<Consumer<Transaction>> beforePageFetch) {
this.beginKeyCursor = beginKeyInclusive;
this.endKeyExclusive = endKeyExclusive;
this.database = database;
this.messageGuidCodec = messageGuidCodec;
this.maxMessagesPerScan = maxMessagesPerScan;
this.messagesAvailableWatchKey = messagesAvailableWatchKey;
this.terminateOnQueueEmpty = messagesAvailableWatchKey == null;
this.stateChangeListener = stateChangeListener != null ? stateChangeListener : (_, _) -> {};
this.beforePageFetch = beforePageFetch;
this.messagePublisher = Flux.create(emitter -> {
this.emitter = emitter;
emitter.onRequest(_ -> transitionStateOnEvent(Event.DEMAND_REQUESTED));
@ -121,16 +121,16 @@ class FoundationDbMessagePublisher {
final KeySelector beginKeyInclusive,
final KeySelector endKeyExclusive,
final Database database,
final MessageGuidCodec messageGuidCodec,
final int maxMessagesPerScan) {
final int maxMessagesPerScan,
final Supplier<Consumer<Transaction>> beforePageFetch) {
return new FoundationDbMessagePublisher(beginKeyInclusive,
endKeyExclusive,
database,
messageGuidCodec,
maxMessagesPerScan,
null,
null);
null,
beforePageFetch);
}
/// Creates a [FoundationDbMessagePublisher] that publishes a non-terminating stream of messages from a device queue.
@ -140,17 +140,17 @@ class FoundationDbMessagePublisher {
final KeySelector beginKeyInclusive,
final KeySelector endKeyExclusive,
final Database database,
final MessageGuidCodec messageGuidCodec,
final int maxMessagesPerScan,
final byte[] messagesAvailableWatchKey) {
final byte[] messagesAvailableWatchKey,
final Supplier<Consumer<Transaction>> beforePageFetch) {
return new FoundationDbMessagePublisher(beginKeyInclusive,
endKeyExclusive,
database,
messageGuidCodec,
maxMessagesPerScan,
messagesAvailableWatchKey,
null);
null,
beforePageFetch);
}
private synchronized void setState(final State newState, final Event event) {
@ -238,25 +238,32 @@ class FoundationDbMessagePublisher {
/// Fetch messages using a range query limiting batch size to [#maxMessagesPerScan]. If the query returns fewer than
/// [#maxMessagesPerScan], emit [Event#FETCHED_ALL_AVAILABLE_MESSAGES]. In the case of an infinite publisher, also set
/// a watch for new messages. Additionally, the cursor is updated so that we begin fetching from the right key on
/// a watch for new messages. Execute the function supplied by [#beforePageFetch] in the context of the transaction.
/// Additionally, the cursor is updated so that we begin fetching from the right key on
/// subsequent scans
///
/// @return a future of a list of [MessageStreamEntry] with a max size of [#maxMessagesPerScan]
private CompletableFuture<List<MessageStreamEntry.Envelope>> getMessagesBatch() {
return database.runAsync(transaction -> getItemsInRange(transaction, messageGuidCodec, beginKeyCursor, endKeyExclusive, maxMessagesPerScan)
.thenApply(lastKeyReadAndItems -> {
// Set our beginning key to just past the last key read so that we're ready for our next fetch
lastKeyReadAndItems.first().ifPresent(lastKeyRead -> beginKeyCursor = KeySelector.firstGreaterThan(lastKeyRead));
/// @return a future of a list of [FoundationDbMessageStreamEntry.Message] with a max size of [#maxMessagesPerScan]
private CompletableFuture<List<FoundationDbMessageStreamEntry.Message>> getMessagesBatch() {
final Consumer<Transaction> doBeforePageFetch = beforePageFetch.get();
return database.runAsync(transaction -> {
doBeforePageFetch.accept(transaction);
return getItemsInRange(transaction, beginKeyCursor, endKeyExclusive, maxMessagesPerScan)
.thenApply(lastKeyReadAndItems -> {
// Set our beginning key to just past the last key read so that we're ready for our next fetch
lastKeyReadAndItems.first()
.ifPresent(lastKeyRead -> beginKeyCursor = KeySelector.firstGreaterThan(lastKeyRead));
final List<MessageStreamEntry.Envelope> items = lastKeyReadAndItems.second();
if (items.size() < maxMessagesPerScan) {
transitionStateOnEvent(Event.FETCHED_ALL_AVAILABLE_MESSAGES);
if (!terminateOnQueueEmpty) {
setWatch(transaction);
}
}
return items;
})
final List<FoundationDbMessageStreamEntry.Message> items = lastKeyReadAndItems.second();
if (items.size() < maxMessagesPerScan) {
transitionStateOnEvent(Event.FETCHED_ALL_AVAILABLE_MESSAGES);
if (!terminateOnQueueEmpty) {
setWatch(transaction);
}
}
return items;
});
}
);
}
@ -267,9 +274,8 @@ class FoundationDbMessagePublisher {
/// @param endExclusive the range end key (exclusive)
/// @param maxMessagesPerScan maximum number of messages to return in the fetch query
/// @return the last key read (if there were non-zero number of messages read) and the list of messages read
private static CompletableFuture<Pair<Optional<byte[]>, List<MessageStreamEntry.Envelope>>> getItemsInRange(
private CompletableFuture<Pair<Optional<byte[]>, List<FoundationDbMessageStreamEntry.Message>>> getItemsInRange(
final Transaction transaction,
final MessageGuidCodec messageGuidCodec,
final KeySelector beginInclusive,
final KeySelector endExclusive,
final int maxMessagesPerScan) {
@ -278,16 +284,10 @@ class FoundationDbMessagePublisher {
final Optional<byte[]> lastKeyRead = keyValues.isEmpty()
? Optional.empty()
: Optional.of(keyValues.getLast().getKey());
final List<MessageStreamEntry.Envelope> messages = keyValues.stream()
final List<FoundationDbMessageStreamEntry.Message> messages = keyValues.stream()
.map(keyValue -> {
try {
return new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(keyValue.getValue())
.toBuilder()
.setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(FoundationDbMessageStore.getVersionstamp(keyValue.getKey()))))
.build());
} catch (final InvalidProtocolBufferException e) {
throw new UncheckedIOException(e);
}
final Versionstamp versionstamp = FoundationDbMessageStore.getVersionstamp(keyValue.getKey());
return new FoundationDbMessageStreamEntry.Message(versionstamp, keyValue.getValue());
})
.toList();
return new Pair<>(lastKeyRead, messages);
@ -317,7 +317,7 @@ class FoundationDbMessagePublisher {
/// Get the stream of messages.
///
/// @return [Flux] of messages
public Flux<MessageStreamEntry.Envelope> getMessages() {
public Flux<FoundationDbMessageStreamEntry.Message> getMessages() {
return this.messagePublisher;
}

View File

@ -25,6 +25,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.Util;
/// An implementation of a message store backed by FoundationDB.
///
@ -261,16 +262,19 @@ public class FoundationDbMessageStore {
}
public MessageStream getMessages(final AciServiceIdentifier aci, final Device destinationDevice) {
return getMessages(aci, destinationDevice, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN);
return getMessages(aci, destinationDevice, FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN,
Util.NOOP);
}
@VisibleForTesting
MessageStream getMessages(final AciServiceIdentifier aci, final Device destinationDevice, final int maxMessagesPerScan) {
MessageStream getMessages(final AciServiceIdentifier aci, final Device destinationDevice,
final int maxMessagesPerScan, final Runnable doAfterCleanup) {
return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, destinationDevice.getId()),
getMessagesAvailableWatchKey(aci),
getShardForAci(aci),
new MessageGuidCodec(aci.uuid(), destinationDevice.getId(), versionstampUUIDCipher),
maxMessagesPerScan);
maxMessagesPerScan,
doAfterCleanup);
}
static Versionstamp getVersionstamp(final byte[] messageKey) {

View File

@ -3,14 +3,30 @@ package org.whispersystems.textsecuregcm.storage.foundationdb;
import com.apple.foundationdb.Database;
import com.apple.foundationdb.KeySelector;
import com.apple.foundationdb.StreamingMode;
import com.apple.foundationdb.Transaction;
import com.apple.foundationdb.subspace.Subspace;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Optional;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import java.util.function.Consumer;
import java.util.function.Supplier;
import com.apple.foundationdb.tuple.ByteArrayUtil;
import com.apple.foundationdb.tuple.Tuple;
import com.apple.foundationdb.tuple.Versionstamp;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.adapter.JdkFlowAdapter;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@ -25,14 +41,21 @@ public class FoundationDbMessageStream implements MessageStream {
/// The maximum number of messages we will fetch per range query operation to avoid excessive memory consumption
private final int maxMessagesPerScan;
private final Flow.Publisher<MessageStreamEntry> messageStreamPublisher;
private final Runnable doAfterCleanup;
/// Map of versionstamp -> acknowledged? to keep track of acknowledged versionstamp ranges to clear
private final NavigableMap<Versionstamp, Boolean> sentVersionstamps = new TreeMap<>();
static final int DEFAULT_MAX_MESSAGES_PER_SCAN = 1024;
private static final Logger LOGGER = LoggerFactory.getLogger(FoundationDbMessageStream.class);
FoundationDbMessageStream(final Subspace deviceQueueSubspace,
final byte[] messagesAvailableWatchKey,
final Database database,
final MessageGuidCodec messageGuidCodec,
final int maxMessagesPerScan) {
final int maxMessagesPerScan,
final Runnable doAfterCleanup) {
this.deviceQueueSubspace = deviceQueueSubspace;
this.messagesAvailableWatchKey = messagesAvailableWatchKey;
@ -40,6 +63,7 @@ public class FoundationDbMessageStream implements MessageStream {
this.messageGuidCodec = messageGuidCodec;
this.maxMessagesPerScan = maxMessagesPerScan;
this.messageStreamPublisher = JdkFlowAdapter.publisherToFlowPublisher(createMessagePublisher());
this.doAfterCleanup = doAfterCleanup;
}
@Override
@ -47,44 +71,62 @@ public class FoundationDbMessageStream implements MessageStream {
return this.messageStreamPublisher;
}
/// Create a message publisher
///
/// @return a Flux of {@link MessageStreamEntry} fetched from FoundationDB
/// @implNote The message publisher is stitched together by concatenating:
/// 1. **A finite message publisher**: On initial request, we record the current end-of-queue key in the device mailbox.
/// Then, we fetch all messages in order until the recorded key and finally complete the stream
/// 2. **A queue-empty signal** is emitted
/// 3. **An infinite message publisher**: We start reading from where the finite publisher left off. When all messages
/// are read, we wait for new messages, publish them, then wait again in a loop forever (until the flux is canceled
/// explicitly or due to an error). This is accomplished by setting a FoundationDB [watch](https://github.com/apple/foundationdb/wiki/An-Overview-how-Watches-Work)
/// on [#messagesAvailableWatchKey] which is updated when a new message is available.
/// See [FoundationDbMessageStore] for more details on the message insert process.
private Flux<MessageStreamEntry> createMessagePublisher() {
return Mono.fromFuture(this::getEndOfQueueKeyExclusive)
/// Create a message publisher
///
/// @return a Flux of {@link MessageStreamEntry} fetched from FoundationDB
/// @implNote turns the stream of [FoundationDbMessageStreamEntry] into [MessageStreamEntry], but taps into the stream
/// first to keep track of versionstamps sent to the client.
private Flux<MessageStreamEntry> createMessagePublisher() {
return createFoundationDbMessagePublisher()
.doOnNext(messageStreamEntry -> {
if (messageStreamEntry instanceof final FoundationDbMessageStreamEntry.Message message) {
onVersionstampSent(message.versionstamp());
}
})
.map(fdbMessageStreamEntry -> fdbMessageStreamEntry.toMessageStreamEntry(messageGuidCodec))
.doFinally(_ -> flushAllAcknowledgedMessages().thenRun(doAfterCleanup));
}
/// Create a message publisher that fetches messages from FoundationDB
///
/// @return a Flux of [FoundationDbMessageStreamEntry] fetched from FoundationDB
/// @implNote The message publisher is stitched together by concatenating:
/// 1. **A finite message publisher**: On initial request, we record the current end-of-queue key in the device mailbox.
/// Then, we fetch all messages in order until the recorded key and finally complete the stream
/// 2. **A queue-empty signal** is emitted
/// 3. **An infinite message publisher**: We start reading from where the finite publisher left off. When all messages
/// are read, we wait for new messages, publish them, then wait again in a loop forever (until the flux is canceled
/// explicitly or due to an error). This is accomplished by setting a FoundationDB [watch](https://github.com/apple/foundationdb/wiki/An-Overview-how-Watches-Work)
/// on [#messagesAvailableWatchKey] which is updated when a new message is available.
/// See [FoundationDbMessageStore] for more details on the message insert process.
private Flux<FoundationDbMessageStreamEntry> createFoundationDbMessagePublisher() {
return Mono.fromFuture(this::getEndOfQueueKeyExclusive)
.flatMapMany(maybeEndOfQueueKeyExclusive -> {
final Flux<MessageStreamEntry.Envelope> finitePublisher = maybeEndOfQueueKeyExclusive
final Flux<FoundationDbMessageStreamEntry.Message> finitePublisher = maybeEndOfQueueKeyExclusive
.map(endOfQueueKeyExclusive -> FoundationDbMessagePublisher.createFinitePublisher(
KeySelector.firstGreaterOrEqual(deviceQueueSubspace.range().begin),
endOfQueueKeyExclusive, database, messageGuidCodec, maxMessagesPerScan).getMessages())
endOfQueueKeyExclusive, database, maxMessagesPerScan, this::clearAcknowledgedMessages).getMessages())
.orElseGet(Flux::empty);
final KeySelector infinitePublisherBeginKey = maybeEndOfQueueKeyExclusive.orElseGet(
() -> KeySelector.firstGreaterOrEqual(deviceQueueSubspace.range().begin));
final Flux<MessageStreamEntry.Envelope> infinitePublisher = FoundationDbMessagePublisher.createInfinitePublisher(
final Flux<FoundationDbMessageStreamEntry.Message> infinitePublisher = FoundationDbMessagePublisher.createInfinitePublisher(
infinitePublisherBeginKey, KeySelector.firstGreaterThan(deviceQueueSubspace.range().end),
database, messageGuidCodec, maxMessagesPerScan, messagesAvailableWatchKey).getMessages();
database, maxMessagesPerScan, messagesAvailableWatchKey, this::clearAcknowledgedMessages).getMessages();
return Flux.concat(
finitePublisher,
Mono.just(new MessageStreamEntry.QueueEmpty()),
Mono.just(new FoundationDbMessageStreamEntry.QueueEmpty()),
infinitePublisher
);
});
}
/// Gets a [KeySelector] for the first key greater than the current greatest key in the device queue. This allows
/// us to query keys up to and including the greatest key, and sets us up to begin reading from the next key in
/// a subsequent scan.
/// @return a [KeySelector] for the first key greater than the current greatest key in the device queue.
private CompletableFuture<Optional<KeySelector>> getEndOfQueueKeyExclusive() {
/// Gets a [KeySelector] for the first key greater than the current greatest key in the device queue. This allows us
/// to query keys up to and including the greatest key, and sets us up to begin reading from the next key in a
/// subsequent scan.
///
/// @return a [KeySelector] for the first key greater than the current greatest key in the device queue.
private CompletableFuture<Optional<KeySelector>> getEndOfQueueKeyExclusive() {
return database.runAsync(
transaction -> transaction.getRange(deviceQueueSubspace.range(), 1, true, StreamingMode.EXACT).asList()
.thenApply(items -> {
@ -98,6 +140,90 @@ public class FoundationDbMessageStream implements MessageStream {
@Override
public CompletableFuture<Void> acknowledgeMessage(final MessageProtos.Envelope message) {
throw new UnsupportedOperationException("Not implemented");
handleAcknowledged(messageGuidCodec.decodeMessageGuid(UUIDUtil.fromByteString(message.getServerGuidBinary())));
return CompletableFuture.completedFuture(null);
}
@VisibleForTesting
synchronized void handleAcknowledged(final Versionstamp versionstamp) {
// If we don't know about this versionstamp, or it's already acknowledged, there's nothing to do
if (!sentVersionstamps.containsKey(versionstamp) || sentVersionstamps.get(versionstamp)) {
return;
}
sentVersionstamps.put(versionstamp, true);
}
@VisibleForTesting
synchronized void onVersionstampSent(final Versionstamp versionstamp) {
sentVersionstamps.put(versionstamp, false);
}
/// Clear the versionstamp range (startInclusive, endInclusive) in a single FoundationDB operation.
///
/// @param transaction The FoundationDB transaction in which to perform the range clear
/// @param startInclusive The starting versionstamp of the range to be cleared (inclusive)
/// @param endInclusive The ending versionstamp of the range to be cleared (inclusive)
private void clearRange(final Transaction transaction, final Versionstamp startInclusive, final Versionstamp endInclusive) {
final byte[] startKeyInclusive = deviceQueueSubspace.pack(Tuple.from(startInclusive));
final byte[] endKeyExclusive = ByteArrayUtil.keyAfter(deviceQueueSubspace.pack(Tuple.from(endInclusive)));
transaction.clear(startKeyInclusive, endKeyExclusive);
}
/// Clear all outstanding acknowledged messages. Called when the stream ends
private CompletableFuture<Void> flushAllAcknowledgedMessages() {
final Consumer<Transaction> clearAllAcknowlegedMessagedConsumer = clearAcknowledgedMessages();
return database.runAsync(transaction -> {
clearAllAcknowlegedMessagedConsumer.accept(transaction);
return CompletableFuture.completedFuture((Void) null);
})
.whenComplete((_, throwable) -> {
if (throwable != null) {
LOGGER.warn("Failed to clear acknowledged messages", throwable);
}
});
}
private synchronized Consumer<Transaction> clearAcknowledgedMessages() {
final List<Pair<Versionstamp, Versionstamp>> flushableRanges = computeFlushableRanges();
flushableRanges.forEach(range -> sentVersionstamps.subMap(range.first(), true, range.second(), true).clear());
return transaction -> flushableRanges.forEach(range -> clearRange(transaction, range.first(), range.second()));
}
/// Computes a list of acknowledged contiguous versionstamp ranges that can be cleared from the database.
///
/// @return a list of acknowledged contiguous versionstamp ranges that can be cleared from the database.
@VisibleForTesting
synchronized List<Pair<Versionstamp, Versionstamp>> computeFlushableRanges() {
final List<Pair<Versionstamp, Versionstamp>> flushableRanges = new ArrayList<>();
Versionstamp startInclusive = null;
Versionstamp endInclusive = null;
for (final Map.Entry<Versionstamp, Boolean> entry : sentVersionstamps.entrySet()) {
if (entry.getValue()) {
// Message is acknowledged, so we can either start tracking a new range if we aren't already, or extend our
// current range
if (startInclusive == null) {
startInclusive = entry.getKey();
}
endInclusive = entry.getKey();
} else {
// Message is un-acknowledged, which means either it is a "range-breaker" or we never started tracking an
// acknowledged range. Mark the currently tracked range flushable, if it exists.
if (startInclusive != null) {
assert endInclusive != null;
flushableRanges.add(new Pair<>(startInclusive, endInclusive));
startInclusive = null;
}
}
}
if (startInclusive != null) {
assert endInclusive != null;
flushableRanges.add(new Pair<>(startInclusive, endInclusive));
}
return flushableRanges;
}
}

View File

@ -0,0 +1,33 @@
package org.whispersystems.textsecuregcm.storage.foundationdb;
import com.apple.foundationdb.tuple.Versionstamp;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.UncheckedIOException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
sealed interface FoundationDbMessageStreamEntry permits FoundationDbMessageStreamEntry.Message,
FoundationDbMessageStreamEntry.QueueEmpty {
record Message(Versionstamp versionstamp, byte[] payload) implements FoundationDbMessageStreamEntry {}
record QueueEmpty() implements FoundationDbMessageStreamEntry {}
default MessageStreamEntry toMessageStreamEntry(final MessageGuidCodec messageGuidCodec) {
return switch (this) {
case Message(final Versionstamp versionstamp, final byte[] payload): {
try {
yield new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(payload)
.toBuilder()
.setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(versionstamp)))
.build());
} catch (final InvalidProtocolBufferException e) {
throw new UncheckedIOException(e);
}
}
case QueueEmpty():
yield new MessageStreamEntry.QueueEmpty();
};
}
}

View File

@ -23,7 +23,9 @@ import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import com.apple.foundationdb.tuple.Tuple;
import com.apple.foundationdb.tuple.Versionstamp;
import com.google.protobuf.InvalidProtocolBufferException;
@ -34,7 +36,6 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.test.StepVerifier;
/// NOTE: most of the happy-path test cases are already covered in {@link FoundationDbMessageStoreTest}, this test
@ -42,7 +43,6 @@ import reactor.test.StepVerifier;
class FoundationDbMessagePublisherTest {
private Database database;
private MessageGuidCodec messageGuidCodec;
private List<FoundationDbMessagePublisher.State> stateTransitions;
private static final AciServiceIdentifier SERVICE_IDENTIFIER = new AciServiceIdentifier(UUID.randomUUID());
@ -53,6 +53,8 @@ class FoundationDbMessagePublisherTest {
private static final byte[] MESSAGES_AVAILABLE_WATCH_KEY =
FoundationDbMessageStore.getMessagesAvailableWatchKey(SERVICE_IDENTIFIER);
private static final Supplier<Consumer<Transaction>> NOOP_ON_PAGE_FETCH = () -> _ -> {};
@BeforeEach
void setUp() {
database = mock(Database.class);
@ -60,10 +62,6 @@ class FoundationDbMessagePublisherTest {
final byte[] messageGuidCodecKey = new byte[16];
new SecureRandom().nextBytes(messageGuidCodecKey);
messageGuidCodec = new MessageGuidCodec(SERVICE_IDENTIFIER.uuid(),
Device.PRIMARY_ID,
new VersionstampUUIDCipher(0, messageGuidCodecKey));
}
@Test
@ -99,10 +97,10 @@ class FoundationDbMessagePublisherTest {
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin),
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.end),
database,
messageGuidCodec,
2, // With 3 messages and batch size set to 2, we'll need to grab 2 batches.
null,
(_, newState) -> stateTransitions.add(newState)
(_, newState) -> stateTransitions.add(newState),
NOOP_ON_PAGE_FETCH
);
StepVerifier.create(finitePublisher.getMessages())
@ -125,7 +123,7 @@ class FoundationDbMessagePublisherTest {
@Test
@SuppressWarnings({"unchecked", "resource"})
void infinitePublisher() throws InvalidProtocolBufferException {
void infinitePublisher() {
final MessageProtos.Envelope message1 = FoundationDbMessageStoreTest.generateRandomMessage(false);
final MessageProtos.Envelope message2 = FoundationDbMessageStoreTest.generateRandomMessage(false);
final MessageProtos.Envelope message3 = FoundationDbMessageStoreTest.generateRandomMessage(false);
@ -170,7 +168,6 @@ class FoundationDbMessagePublisherTest {
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin),
KeySelector.firstGreaterOrEqual(new byte[]{(byte) 10}),
database,
messageGuidCodec,
2,
MESSAGES_AVAILABLE_WATCH_KEY,
(oldState, newState) -> {
@ -187,7 +184,8 @@ class FoundationDbMessagePublisherTest {
watchFuture2.complete(null);
}
}
}
},
NOOP_ON_PAGE_FETCH
);
StepVerifier.create(infinitePublisher.getMessages())
@ -252,7 +250,6 @@ class FoundationDbMessagePublisherTest {
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin),
KeySelector.firstGreaterOrEqual(new byte[]{(byte) 10}),
database,
messageGuidCodec,
2,
MESSAGES_AVAILABLE_WATCH_KEY,
(oldState, newState) -> {
@ -267,7 +264,8 @@ class FoundationDbMessagePublisherTest {
.count() == 1) {
watchFuture1.complete(null);
}
}
},
NOOP_ON_PAGE_FETCH
);
StepVerifier.create(infinitePublisher.getMessages())
@ -298,9 +296,9 @@ class FoundationDbMessagePublisherTest {
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin),
KeySelector.firstGreaterThan(SUBSPACE_RANGE.end),
database,
messageGuidCodec,
100,
MESSAGES_AVAILABLE_WATCH_KEY);
MESSAGES_AVAILABLE_WATCH_KEY,
() -> _ -> {});
final MessageProtos.Envelope message = FoundationDbMessageStoreTest.generateRandomMessage(false);
final Transaction transaction = mock(Transaction.class);
final KeyValue keyValue = mockKeyValue((byte) 5, message);
@ -326,12 +324,8 @@ class FoundationDbMessagePublisherTest {
verify(watchFuture).cancel(true);
}
private MessageStreamEntry.Envelope getExpectedMessageStreamEntry(final KeyValue keyValue)
throws InvalidProtocolBufferException {
return new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(keyValue.getValue())
.toBuilder()
.setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(FoundationDbMessageStore.getVersionstamp(keyValue.getKey()))))
.build());
private FoundationDbMessageStreamEntry.Message getExpectedMessageStreamEntry(final KeyValue keyValue) {
return new FoundationDbMessageStreamEntry.Message(FoundationDbMessageStore.getVersionstamp(keyValue.getKey()), keyValue.getValue());
}
private static KeyValue mockKeyValue(final byte key, final MessageProtos.Envelope message) {

View File

@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
@ -31,9 +32,11 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
@ -56,8 +59,10 @@ import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.adapter.JdkFlowAdapter;
import reactor.test.StepVerifier;
import javax.annotation.Nullable;
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class FoundationDbMessageStoreTest {
@ -404,7 +409,8 @@ class FoundationDbMessageStoreTest {
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, batchSize);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, batchSize,
Util.NOOP);
final List<MessageStreamEntry> retrievedEntries = new ArrayList<>();
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()))
.recordWith(() -> retrievedEntries)
@ -550,6 +556,82 @@ class FoundationDbMessageStoreTest {
.verifyTimeout(Duration.ofSeconds(3));
}
@ParameterizedTest
@MethodSource
void acknowledgeMessages(final int numMessages, final Set<Integer> unacknowledgedMessages)
throws InterruptedException {
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
final MessageGuidCodec messageGuidCodec =
new MessageGuidCodec(aci.uuid(), Device.PRIMARY_ID, versionstampUUIDCipher);
writePresenceKey(aci, Device.PRIMARY_ID, 1, 5L);
final List<Versionstamp> versionstamps = IntStream.range(0, numMessages)
.mapToObj(
_ -> foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join()
.get(Device.PRIMARY_ID)
.versionstamp()
.orElseThrow())
.toList();
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final CountDownLatch latch = new CountDownLatch(1);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device,
FoundationDbMessageStream.DEFAULT_MAX_MESSAGES_PER_SCAN,
latch::countDown);
final List<CompletableFuture<Void>> acknowledgeFutures = new ArrayList<>();
final AtomicInteger messageCounter = new AtomicInteger(0);
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())
.doOnNext(entry -> {
final int messageNum = messageCounter.getAndIncrement();
if (!unacknowledgedMessages.contains(messageNum) && entry instanceof MessageStreamEntry.Envelope(final MessageProtos.Envelope message)) {
acknowledgeFutures.add(messageStream.acknowledgeMessage(message));
}
}))
.expectNextCount(numMessages)
.expectNext(new MessageStreamEntry.QueueEmpty())
.verifyTimeout(Duration.ofSeconds(1));
CompletableFuture.allOf(acknowledgeFutures.toArray(CompletableFuture[]::new)).join();
final List<Versionstamp> expectedDeletedVersionstamps = IntStream.range(0, numMessages)
.filter(i -> !unacknowledgedMessages.contains(i))
.mapToObj(versionstamps::get)
.toList();
// Clean up can take a bit after subscription cancellation, so wait for the countdown latch to complete
if (!expectedDeletedVersionstamps.isEmpty()) {
assertTrue(latch.await(1, TimeUnit.SECONDS));
expectedDeletedVersionstamps.forEach(
versionstamp -> assertNull(getMessageByVersionstamp(aci, Device.PRIMARY_ID, versionstamp)));
}
// Expect that the unacknowledged messages are re-delivered when we connect again.
final List<Versionstamp> expectedRedeliveredVersionstamps = IntStream.range(0, numMessages)
.filter(unacknowledgedMessages::contains)
.mapToObj(versionstamps::get)
.toList();
final List<MessageStreamEntry> retrievedEntries = new ArrayList<>();
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()))
.recordWith(() -> retrievedEntries)
.expectNextCount(expectedRedeliveredVersionstamps.size())
.expectNext(new MessageStreamEntry.QueueEmpty())
.verifyTimeout(Duration.ofSeconds(1));
assertEquals(expectedRedeliveredVersionstamps, retrievedEntries.stream()
.filter(e -> e instanceof MessageStreamEntry.Envelope)
.map(e -> messageGuidCodec.decodeMessageGuid(UUIDUtil.fromByteString(((MessageStreamEntry.Envelope) e).message().getServerGuidBinary())))
.toList());
}
static Stream<Arguments> acknowledgeMessages() {
return Stream.of(
Arguments.argumentSet("Single acknowledged message", 1, Collections.emptySet()),
Arguments.argumentSet("Multiple messages, all acknowledged", 16, Collections.emptySet()),
Arguments.argumentSet("Multiple messages, single unacknowledged", 16, Set.of(3)),
Arguments.argumentSet("Multiple messages with range-breakers", 16, Set.of(3, 7, 8, 9, 12))
);
}
static MessageProtos.Envelope generateRandomMessage(final boolean ephemeral) {
return generateRandomMessage(ephemeral, 16);
}
@ -561,10 +643,11 @@ class FoundationDbMessageStoreTest {
.build();
}
@Nullable
private byte[] getMessageByVersionstamp(final AciServiceIdentifier aci, final byte deviceId,
final Versionstamp versionstamp) {
return foundationDbMessageStore.getShardForAci(aci).read(transaction -> {
final byte[] key = foundationDbMessageStore.getDeviceQueueSubspace(aci, deviceId)
final byte[] key = FoundationDbMessageStore.getDeviceQueueSubspace(aci, deviceId)
.pack(Tuple.from(versionstamp));
return transaction.get(key);
}).join();
@ -572,7 +655,7 @@ class FoundationDbMessageStoreTest {
private Optional<Versionstamp> getMessagesAvailableWatch(final AciServiceIdentifier aci) {
return foundationDbMessageStore.getShardForAci(aci)
.read(transaction -> transaction.get(foundationDbMessageStore.getMessagesAvailableWatchKey(aci))
.read(transaction -> transaction.get(FoundationDbMessageStore.getMessagesAvailableWatchKey(aci))
.thenApply(value -> value == null ? null : Tuple.fromBytes(value).getVersionstamp(0))
.thenApply(Optional::ofNullable))
.join();
@ -609,7 +692,7 @@ class FoundationDbMessageStoreTest {
private List<KeyValue> getItemsInDeviceQueue(final AciServiceIdentifier aci, final byte deviceId) {
return foundationDbMessageStore.getShardForAci(aci).readAsync(transaction -> AsyncUtil.collect(transaction.getRange(
foundationDbMessageStore.getDeviceQueueSubspace(aci, deviceId).range()))).join();
FoundationDbMessageStore.getDeviceQueueSubspace(aci, deviceId).range()))).join();
}
}

View File

@ -0,0 +1,73 @@
package org.whispersystems.textsecuregcm.storage.foundationdb;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import com.apple.foundationdb.Database;
import com.apple.foundationdb.subspace.Subspace;
import com.apple.foundationdb.tuple.Versionstamp;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
class FoundationDbMessageStreamTest {
@ParameterizedTest
@MethodSource
void computeFlushableRanges(final List<Integer> versionstampsSent, final List<Integer> versionstampsAcknowledged,
final List<Pair<Integer, Integer>> expectedRanges) {
final FoundationDbMessageStream foundationDbMessageStream = new FoundationDbMessageStream(
mock(Subspace.class),
new byte[]{0},
mock(Database.class),
mock(MessageGuidCodec.class),
100,
Util.NOOP
);
versionstampsSent.forEach(
versionstamp -> foundationDbMessageStream.onVersionstampSent(versionstampFromInt(versionstamp)));
versionstampsAcknowledged.forEach(
versionstamp -> foundationDbMessageStream.handleAcknowledged(versionstampFromInt(versionstamp)));
final List<Pair<Versionstamp, Versionstamp>> flushableRanges = foundationDbMessageStream.computeFlushableRanges();
assertEquals(expectedRanges
.stream()
.map(range -> new Pair<>(versionstampFromInt(range.first()), versionstampFromInt(range.second())))
.toList(),
flushableRanges);
}
static Stream<Arguments> computeFlushableRanges() {
return Stream.of(
Arguments.argumentSet("Single acknowledged message range", List.of(1, 2, 3), List.of(1),
List.of(new Pair<>(1, 1))),
Arguments.argumentSet("No acknowledged messages", List.of(1, 2, 3), Collections.emptyList(),
Collections.emptyList()),
Arguments.argumentSet("Fully acknowledged message range", List.of(1, 2, 3, 4),
List.of(1, 2, 3, 4), List.of(new Pair<>(1, 4))),
Arguments.argumentSet("Not all messages acknowledged", List.of(1, 2, 3, 4),
List.of(1, 2, 3), List.of(new Pair<>(1, 3))),
Arguments.argumentSet("Out-of-order acknowledged ranges", List.of(1, 2, 3, 4, 5, 6, 7, 8),
List.of(1, 3, 4, 5, 7, 8), List.of(new Pair<>(1, 1), new Pair<>(3, 5), new Pair<>(7, 8))),
Arguments.argumentSet("Out-of-order acknowledged single messages", List.of(1, 2, 3, 4, 5, 6, 7, 8),
List.of(1, 3, 5), List.of(new Pair<>(1, 1), new Pair<>(3, 3), new Pair<>(5, 5)))
);
}
private static Versionstamp versionstampFromInt(final int version) {
final ByteBuffer buf = ByteBuffer.allocate(10).order(ByteOrder.BIG_ENDIAN);
buf.putLong(version); // 8 bytes: transaction version
buf.putShort((short) 2); // 2 bytes: batch order within transaction
return Versionstamp.complete(buf.array());
}
}