From 671a6e1d7c56a43e6b3026e020353243d5f2d73c Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer <125505367+jkt-signal@users.noreply.github.com> Date: Mon, 22 Jun 2026 11:04:58 -0700 Subject: [PATCH] address review comments Co-authored-by: Jon Chambers <63609320+jon-signal@users.noreply.github.com> --- .../FoundationDbMessageStore.java | 104 ++++++++++-------- .../foundationdb/VersionstampClock.java | 12 +- .../FoundationDbMessageStoreTest.java | 12 +- .../foundationdb/VersionstampClockTest.java | 10 +- 4 files changed, 81 insertions(+), 57 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java index 1bdf59e00..fa9626c1b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -8,9 +8,6 @@ import com.apple.foundationdb.subspace.Subspace; import com.apple.foundationdb.tuple.Tuple; import com.apple.foundationdb.tuple.Versionstamp; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Multimap; -import com.google.common.collect.MultimapBuilder; -import com.google.common.collect.Multimaps; import com.google.common.hash.Hashing; import io.dropwizard.util.DataSize; import io.micrometer.core.instrument.Counter; @@ -21,6 +18,7 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -37,8 +35,6 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessageStream; import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Util; -import reactor.util.function.Tuple2; -import reactor.util.function.Tuples; /// An implementation of a message store backed by FoundationDB. /// @@ -55,6 +51,7 @@ public class FoundationDbMessageStore { private final Database[][] databasesByEpoch; private final Map versionstampClocks; + private final int[] liveEpochs; private final int activeEpoch; private final VersionstampUUIDCipher versionstampUUIDCipher; private final Clock clock; @@ -96,23 +93,29 @@ public class FoundationDbMessageStore { boolean present) { } - public FoundationDbMessageStore(final Map> databasesByEpoch, + public FoundationDbMessageStore(final Map> databasesByEpochMap, final int activeEpoch, final VersionstampUUIDCipher versionstampUUIDCipher, final Clock clock) { final Database[][] databasesByEpochArray = new Database[MAX_EPOCHS][]; - databasesByEpoch.forEach((epoch, databases) -> + databasesByEpochMap.forEach((epoch, databases) -> databasesByEpochArray[epoch] = databases.toArray(Database[]::new)); this.databasesByEpoch = databasesByEpochArray; + this.liveEpochs = IntStream.range(0, MAX_EPOCHS).filter(e -> databasesByEpochArray[e] != null).toArray(); this.activeEpoch = activeEpoch; this.versionstampUUIDCipher = versionstampUUIDCipher; - this.versionstampClocks = databasesByEpoch.values().stream() + this.versionstampClocks = databasesByEpochMap.values().stream() .flatMap(List::stream) .distinct() - .collect(Collectors.toMap(Function.identity(), db -> new VersionstampClock(db, clock))); + .collect(Collectors.toMap(Function.identity(), + db -> new VersionstampClock(db, clock), + (_, _) -> { + throw new AssertionError("Source stream had duplicates after distinct()"); + }, + IdentityHashMap::new)); this.clock = clock; } @@ -394,10 +397,8 @@ public class FoundationDbMessageStore { // For each configured database epoch, which database held (or holds) the messages for this ACI/device pair? final Database[] databasesForQueueByEpoch = new Database[databasesByEpoch.length]; - for (int epoch = 0; epoch < databasesByEpoch.length; epoch++) { - databasesForQueueByEpoch[epoch] = databasesByEpoch[epoch] != null - ? getDatabases(epoch)[hashAciToShardNumber(aci, epoch)] - : null; + for (final int epoch : liveEpochs) { + databasesForQueueByEpoch[epoch] = getShardForAci(aci, epoch); } return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, destinationDevice.getId()), @@ -411,41 +412,56 @@ public class FoundationDbMessageStore { /// Record the versionstamp for the current time in each database's versionstamp clock. public void recordVersionstamps() { - CompletableFuture - .allOf( - versionstampClocks.values().stream() - .map(VersionstampClock::recordVersionstampAndTime) - .toArray(CompletableFuture[]::new)) - .join(); + for (VersionstampClock versionstampClock : versionstampClocks.values()) { + versionstampClock.recordVersionstampAndTime(); + } } + /// Delete messages for the given devices that were inserted before the given time. + /// + /// This issues one range delete for every account/device pair in the map for each actual underlying FoundationDB + /// cluster that ever hosted that account in a configured epoch (not just the current epoch). While range deletes are + /// efficient, they can potentially induce significant load on the storage process, so callers should be judicious + /// with flow control when calling this method. + /// + /// @param accountDeviceIdentifiers a map from ACI to the list of device IDs for which expired messages should be + /// trimmed. Note that, depending on the sharding schema of the configured message store, it is possible that we will + /// issue a range clear for every device in a single transaction. One range clear involves two keys, each of which has + /// an ACI, device ID, and versionstamp, along with overhead to identify the relevant subspace; given the 10MB + /// transaction limit and a healthy margin for safety, this means there should be well under 100,000 total + /// account/device pairs supplied in a single call to this method. + /// + /// @param cutoffTime the expiration threshold. Messages inserted before this time may be deleted; messages inserted + /// after it will not be. Deletion depends on the underlying [versionstamp clocks][VersionstampClock] being kept up to + /// date. @VisibleForTesting - CompletableFuture deleteMessagesBefore(final Map> accountDeviceIdentifiers, final Instant cutoffTime) { - final List liveEpochs = IntStream.range(0, MAX_EPOCHS).filter(e -> databasesByEpoch[e] != null).boxed().toList(); - final Multimap queueSpacesByDatabase = accountDeviceIdentifiers.entrySet() - .stream() - .flatMap(entry -> liveEpochs.stream().map(e -> getShardForAci(entry.getKey(), e)).distinct().map(db -> Tuples.of(db, entry))) - .collect( - Multimaps.flatteningToMultimap( - Tuple2::getT1, - dbAndAciAndDevices -> { - final AciServiceIdentifier aci = dbAndAciAndDevices.getT2().getKey(); - final List devices = dbAndAciAndDevices.getT2().getValue(); - return devices.stream().map(deviceId -> getDeviceQueueSubspace(aci, deviceId)); - }, - MultimapBuilder.hashKeys().arrayListValues()::build)); + public void deleteMessagesBefore(final Map> accountDeviceIdentifiers, final Instant cutoffTime) { + final Map> queueSubspacesToTrimByDatabase = new IdentityHashMap<>(); - return CompletableFuture.allOf( - versionstampClocks.entrySet().stream() - .flatMap(dbAndClock -> dbAndClock.getValue().getVersionstamp(cutoffTime).map(v -> Tuples.of(dbAndClock.getKey(), v)).stream()) - .map(dbAndVersionstamp -> { - final Database db = dbAndVersionstamp.getT1(); - final Versionstamp versionstamp = dbAndVersionstamp.getT2(); - return db.runAsync(txn -> { - queueSpacesByDatabase.get(db).forEach(s -> txn.clear(new Range(s.pack(Tuple.from()), s.pack(Tuple.from(versionstamp))))); - return CompletableFuture.completedFuture(null); - }); - }).toArray(CompletableFuture[]::new)); + accountDeviceIdentifiers.forEach((aci, deviceIds) -> { + for (final byte deviceId : deviceIds) { + final Subspace queueSubspace = getDeviceQueueSubspace(aci, deviceId); + IntStream.of(liveEpochs) + .mapToObj(e -> getShardForAci(aci, e)) + .distinct() + .forEach(database -> queueSubspacesToTrimByDatabase.computeIfAbsent(database, _ -> new ArrayList<>()).add(queueSubspace)); + } + }); + + queueSubspacesToTrimByDatabase.forEach((database, queueSubspaces) -> { + // It's OK that this puts reading the versionstamp in a separate transaction from the deletes; versionstamp clock + // entries are effectively immutable and we're looking for one that was written presumably very far in the past, + // so there's no conflict to avoid + versionstampClocks.get(database).getVersionstamp(cutoffTime).ifPresent(cutoffVersionstamp -> { + database.run(transaction -> { + transaction.options().setPriorityBatch(); + for (final Subspace queueSubspace : queueSubspaces) { + transaction.clear(new Range(queueSubspace.getKey(), queueSubspace.pack(Tuple.from(cutoffVersionstamp)))); + } + return null; + }); + }); + }); } static Versionstamp getVersionstamp(final byte[] messageKey) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClock.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClock.java index e66e92ab9..ee31366e3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClock.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClock.java @@ -16,7 +16,6 @@ import java.time.Clock; import java.time.Instant; import java.util.Iterator; import java.util.Optional; -import java.util.concurrent.CompletableFuture; /// A "versionstamp clock" records versionstamp/time pairs in a FoundationDB database. The purpose of the "clock" is to /// allow callers to get the versionstamp of a FoundationDB database at a specific real-world time to facilitate @@ -45,18 +44,19 @@ public class VersionstampClock { /// Make a recording in the database of the current time and associated versionstamp. /// - /// @return a future for the versionstamp for the newly-recorded entry - public CompletableFuture recordVersionstampAndTime() { + /// @return the versionstamp for the newly-recorded entry + public Versionstamp recordVersionstampAndTime() { final Instant currentTime = clock.instant(); - return database.runAsync(transaction -> { + return database.run(transaction -> { transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getTimestampKey(currentTime), Tuple.from(Versionstamp.incomplete()).packWithVersionstamp()); - return CompletableFuture.completedFuture(transaction.getVersionstamp()); + return transaction.getVersionstamp(); }) - .thenApply(ff -> Versionstamp.complete(ff.join())); + .thenApply(Versionstamp::complete) + .join(); } /// Returns the highest versionstamp recorded at or before the given instant. diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java index b6286842a..949467a38 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -1019,12 +1019,20 @@ class FoundationDbMessageStoreTest { CLOCK.pin(testTime); foundationDbMessageStore.deleteMessagesBefore( acis.stream().collect(Collectors.toMap(Function.identity(), _ -> List.of(Device.PRIMARY_ID))), - threshold).join(); + threshold); // make sure we have new but not old messages for (AciServiceIdentifier aci : acis) { final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); - final List messages = JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()).takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty).filter(entry -> entry instanceof MessageStreamEntry.Envelope).cast(MessageStreamEntry.Envelope.class).map(MessageStreamEntry.Envelope::message).collectList().block(); + final List messages = + JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()) + .takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty) + .filter(entry -> entry instanceof MessageStreamEntry.Envelope) + .cast(MessageStreamEntry.Envelope.class) + .map(MessageStreamEntry.Envelope::message) + .collectList() + .blockOptional() + .orElseGet(Collections::emptyList); assertEquals(10, messages.size()); for (MessageProtos.Envelope m : messages) { assertEquals(newTime.toEpochMilli(), m.getServerTimestamp()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClockTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClockTest.java index b30858145..20338d3d2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClockTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClockTest.java @@ -47,10 +47,10 @@ class VersionstampClockTest { final Instant tomorrow = start.plus(Duration.ofDays(1)); clock.pin(firstInsert); - final Versionstamp firstStamp = versionstampClock.recordVersionstampAndTime().join(); + final Versionstamp firstStamp = versionstampClock.recordVersionstampAndTime(); clock.pin(secondInsert); - final Versionstamp secondStamp = versionstampClock.recordVersionstampAndTime().join(); + final Versionstamp secondStamp = versionstampClock.recordVersionstampAndTime(); assertThat(versionstampClock.getVersionstamp(start)).isEmpty(); assertThat(versionstampClock.getVersionstamp(firstInsert)).isPresent().hasValue(firstStamp); @@ -66,13 +66,13 @@ class VersionstampClockTest { final Instant old = start.minus(Duration.ofDays(21)); clock.pin(veryOld); - final Versionstamp veryOldStamp = versionstampClock.recordVersionstampAndTime().join(); + final Versionstamp veryOldStamp = versionstampClock.recordVersionstampAndTime(); clock.pin(old); - final Versionstamp oldStamp = versionstampClock.recordVersionstampAndTime().join(); + final Versionstamp oldStamp = versionstampClock.recordVersionstampAndTime(); clock.pin(start); - final Versionstamp nowStamp = versionstampClock.recordVersionstampAndTime().join(); + final Versionstamp nowStamp = versionstampClock.recordVersionstampAndTime(); assertThat(versionstampClock.getVersionstamp(veryOld)).isPresent().hasValue(veryOldStamp);