address review comments

Co-authored-by: Jon Chambers <63609320+jon-signal@users.noreply.github.com>
This commit is contained in:
Jonathan Klabunde Tomer 2026-06-22 11:04:58 -07:00 committed by Chris Eager
parent f4e16676c9
commit 671a6e1d7c
4 changed files with 81 additions and 57 deletions

View File

@ -8,9 +8,6 @@ import com.apple.foundationdb.subspace.Subspace;
import com.apple.foundationdb.tuple.Tuple;
import com.apple.foundationdb.tuple.Versionstamp;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Multimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Multimaps;
import com.google.common.hash.Hashing;
import io.dropwizard.util.DataSize;
import io.micrometer.core.instrument.Counter;
@ -21,6 +18,7 @@ import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -37,8 +35,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
/// An implementation of a message store backed by FoundationDB.
///
@ -55,6 +51,7 @@ public class FoundationDbMessageStore {
private final Database[][] databasesByEpoch;
private final Map<Database, VersionstampClock> versionstampClocks;
private final int[] liveEpochs;
private final int activeEpoch;
private final VersionstampUUIDCipher versionstampUUIDCipher;
private final Clock clock;
@ -96,23 +93,29 @@ public class FoundationDbMessageStore {
boolean present) {
}
public FoundationDbMessageStore(final Map<Integer, List<Database>> databasesByEpoch,
public FoundationDbMessageStore(final Map<Integer, List<Database>> databasesByEpochMap,
final int activeEpoch,
final VersionstampUUIDCipher versionstampUUIDCipher,
final Clock clock) {
final Database[][] databasesByEpochArray = new Database[MAX_EPOCHS][];
databasesByEpoch.forEach((epoch, databases) ->
databasesByEpochMap.forEach((epoch, databases) ->
databasesByEpochArray[epoch] = databases.toArray(Database[]::new));
this.databasesByEpoch = databasesByEpochArray;
this.liveEpochs = IntStream.range(0, MAX_EPOCHS).filter(e -> databasesByEpochArray[e] != null).toArray();
this.activeEpoch = activeEpoch;
this.versionstampUUIDCipher = versionstampUUIDCipher;
this.versionstampClocks = databasesByEpoch.values().stream()
this.versionstampClocks = databasesByEpochMap.values().stream()
.flatMap(List::stream)
.distinct()
.collect(Collectors.toMap(Function.identity(), db -> new VersionstampClock(db, clock)));
.collect(Collectors.toMap(Function.identity(),
db -> new VersionstampClock(db, clock),
(_, _) -> {
throw new AssertionError("Source stream had duplicates after distinct()");
},
IdentityHashMap::new));
this.clock = clock;
}
@ -394,10 +397,8 @@ public class FoundationDbMessageStore {
// For each configured database epoch, which database held (or holds) the messages for this ACI/device pair?
final Database[] databasesForQueueByEpoch = new Database[databasesByEpoch.length];
for (int epoch = 0; epoch < databasesByEpoch.length; epoch++) {
databasesForQueueByEpoch[epoch] = databasesByEpoch[epoch] != null
? getDatabases(epoch)[hashAciToShardNumber(aci, epoch)]
: null;
for (final int epoch : liveEpochs) {
databasesForQueueByEpoch[epoch] = getShardForAci(aci, epoch);
}
return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, destinationDevice.getId()),
@ -411,41 +412,56 @@ public class FoundationDbMessageStore {
/// Record the versionstamp for the current time in each database's versionstamp clock.
public void recordVersionstamps() {
CompletableFuture
.allOf(
versionstampClocks.values().stream()
.map(VersionstampClock::recordVersionstampAndTime)
.toArray(CompletableFuture[]::new))
.join();
for (VersionstampClock versionstampClock : versionstampClocks.values()) {
versionstampClock.recordVersionstampAndTime();
}
}
/// Delete messages for the given devices that were inserted before the given time.
///
/// This issues one range delete for every account/device pair in the map for each actual underlying FoundationDB
/// cluster that ever hosted that account in a configured epoch (not just the current epoch). While range deletes are
/// efficient, they can potentially induce significant load on the storage process, so callers should be judicious
/// with flow control when calling this method.
///
/// @param accountDeviceIdentifiers a map from ACI to the list of device IDs for which expired messages should be
/// trimmed. Note that, depending on the sharding scheme of the configured message store, it is possible that we will
/// issue a range clear for every device in a single transaction. One range clear involves two keys, each of which has
/// an ACI, device ID, and versionstamp, along with overhead to identify the relevant subspace; given the 10MB
/// transaction limit and a healthy margin for safety, this means there should be well under 100,000 total
/// account/device pairs supplied in a single call to this method.
///
/// @param cutoffTime the expiration threshold. Messages inserted before this time may be deleted; messages inserted
/// after it will not be. Deletion depends on the underlying [versionstamp clocks][VersionstampClock] being kept up to
/// date.
@VisibleForTesting
CompletableFuture<Void> deleteMessagesBefore(final Map<AciServiceIdentifier, List<Byte>> accountDeviceIdentifiers, final Instant cutoffTime) {
final List<Integer> liveEpochs = IntStream.range(0, MAX_EPOCHS).filter(e -> databasesByEpoch[e] != null).boxed().toList();
final Multimap<Database, Subspace> queueSpacesByDatabase = accountDeviceIdentifiers.entrySet()
.stream()
.flatMap(entry -> liveEpochs.stream().map(e -> getShardForAci(entry.getKey(), e)).distinct().map(db -> Tuples.of(db, entry)))
.collect(
Multimaps.flatteningToMultimap(
Tuple2::getT1,
dbAndAciAndDevices -> {
final AciServiceIdentifier aci = dbAndAciAndDevices.getT2().getKey();
final List<Byte> devices = dbAndAciAndDevices.getT2().getValue();
return devices.stream().map(deviceId -> getDeviceQueueSubspace(aci, deviceId));
},
MultimapBuilder.hashKeys().arrayListValues()::build));
public void deleteMessagesBefore(final Map<AciServiceIdentifier, List<Byte>> accountDeviceIdentifiers, final Instant cutoffTime) {
final Map<Database, List<Subspace>> queueSubspacesToTrimByDatabase = new IdentityHashMap<>();
return CompletableFuture.allOf(
versionstampClocks.entrySet().stream()
.flatMap(dbAndClock -> dbAndClock.getValue().getVersionstamp(cutoffTime).map(v -> Tuples.of(dbAndClock.getKey(), v)).stream())
.map(dbAndVersionstamp -> {
final Database db = dbAndVersionstamp.getT1();
final Versionstamp versionstamp = dbAndVersionstamp.getT2();
return db.runAsync(txn -> {
queueSpacesByDatabase.get(db).forEach(s -> txn.clear(new Range(s.pack(Tuple.from()), s.pack(Tuple.from(versionstamp)))));
return CompletableFuture.<Void>completedFuture(null);
});
}).toArray(CompletableFuture[]::new));
accountDeviceIdentifiers.forEach((aci, deviceIds) -> {
for (final byte deviceId : deviceIds) {
final Subspace queueSubspace = getDeviceQueueSubspace(aci, deviceId);
IntStream.of(liveEpochs)
.mapToObj(e -> getShardForAci(aci, e))
.distinct()
.forEach(database -> queueSubspacesToTrimByDatabase.computeIfAbsent(database, _ -> new ArrayList<>()).add(queueSubspace));
}
});
queueSubspacesToTrimByDatabase.forEach((database, queueSubspaces) -> {
// It's OK that this puts reading the versionstamp in a separate transaction from the deletes; versionstamp clock
// entries are effectively immutable and we're looking for one that was written presumably very far in the past,
// so there's no conflict to avoid
versionstampClocks.get(database).getVersionstamp(cutoffTime).ifPresent(cutoffVersionstamp -> {
database.run(transaction -> {
transaction.options().setPriorityBatch();
for (final Subspace queueSubspace : queueSubspaces) {
transaction.clear(new Range(queueSubspace.getKey(), queueSubspace.pack(Tuple.from(cutoffVersionstamp))));
}
return null;
});
});
});
}
static Versionstamp getVersionstamp(final byte[] messageKey) {

View File

@ -16,7 +16,6 @@ import java.time.Clock;
import java.time.Instant;
import java.util.Iterator;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
/// A "versionstamp clock" records versionstamp/time pairs in a FoundationDB database. The purpose of the "clock" is to
/// allow callers to get the versionstamp of a FoundationDB database at a specific real-world time to facilitate
@ -45,18 +44,19 @@ public class VersionstampClock {
/// Make a recording in the database of the current time and associated versionstamp.
///
/// @return a future for the versionstamp for the newly-recorded entry
public CompletableFuture<Versionstamp> recordVersionstampAndTime() {
/// @return the versionstamp for the newly-recorded entry
public Versionstamp recordVersionstampAndTime() {
final Instant currentTime = clock.instant();
return database.runAsync(transaction -> {
return database.run(transaction -> {
transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE,
getTimestampKey(currentTime),
Tuple.from(Versionstamp.incomplete()).packWithVersionstamp());
return CompletableFuture.completedFuture(transaction.getVersionstamp());
return transaction.getVersionstamp();
})
.thenApply(ff -> Versionstamp.complete(ff.join()));
.thenApply(Versionstamp::complete)
.join();
}
/// Returns the highest versionstamp recorded at or before the given instant.

View File

@ -1019,12 +1019,20 @@ class FoundationDbMessageStoreTest {
CLOCK.pin(testTime);
foundationDbMessageStore.deleteMessagesBefore(
acis.stream().collect(Collectors.toMap(Function.identity(), _ -> List.of(Device.PRIMARY_ID))),
threshold).join();
threshold);
// make sure we have new but not old messages
for (AciServiceIdentifier aci : acis) {
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device);
final List<MessageProtos.Envelope> messages = JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()).takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty).filter(entry -> entry instanceof MessageStreamEntry.Envelope).cast(MessageStreamEntry.Envelope.class).map(MessageStreamEntry.Envelope::message).collectList().block();
final List<MessageProtos.Envelope> messages =
JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())
.takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty)
.filter(entry -> entry instanceof MessageStreamEntry.Envelope)
.cast(MessageStreamEntry.Envelope.class)
.map(MessageStreamEntry.Envelope::message)
.collectList()
.blockOptional()
.orElseGet(Collections::emptyList);
assertEquals(10, messages.size());
for (MessageProtos.Envelope m : messages) {
assertEquals(newTime.toEpochMilli(), m.getServerTimestamp());

View File

@ -47,10 +47,10 @@ class VersionstampClockTest {
final Instant tomorrow = start.plus(Duration.ofDays(1));
clock.pin(firstInsert);
final Versionstamp firstStamp = versionstampClock.recordVersionstampAndTime().join();
final Versionstamp firstStamp = versionstampClock.recordVersionstampAndTime();
clock.pin(secondInsert);
final Versionstamp secondStamp = versionstampClock.recordVersionstampAndTime().join();
final Versionstamp secondStamp = versionstampClock.recordVersionstampAndTime();
assertThat(versionstampClock.getVersionstamp(start)).isEmpty();
assertThat(versionstampClock.getVersionstamp(firstInsert)).isPresent().hasValue(firstStamp);
@ -66,13 +66,13 @@ class VersionstampClockTest {
final Instant old = start.minus(Duration.ofDays(21));
clock.pin(veryOld);
final Versionstamp veryOldStamp = versionstampClock.recordVersionstampAndTime().join();
final Versionstamp veryOldStamp = versionstampClock.recordVersionstampAndTime();
clock.pin(old);
final Versionstamp oldStamp = versionstampClock.recordVersionstampAndTime().join();
final Versionstamp oldStamp = versionstampClock.recordVersionstampAndTime();
clock.pin(start);
final Versionstamp nowStamp = versionstampClock.recordVersionstampAndTime().join();
final Versionstamp nowStamp = versionstampClock.recordVersionstampAndTime();
assertThat(versionstampClock.getVersionstamp(veryOld)).isPresent().hasValue(veryOldStamp);