From f4e16676c9e6ec2c182b1f6aeb76b0bc20d55f71 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer Date: Thu, 18 Jun 2026 17:12:43 -0700 Subject: [PATCH] add FoundationDbMessageStore api to clear messages before a given time --- .../textsecuregcm/storage/MessageStream.java | 1 - .../FoundationDbMessageStore.java | 57 ++++++++++++++++++- .../foundationdb/VersionstampClock.java | 12 ++-- .../FoundationDbMessageStoreTest.java | 56 +++++++++++++++++- .../foundationdb/VersionstampClockTest.java | 10 ++-- 5 files changed, 118 insertions(+), 18 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageStream.java index 38f7f8959..abb665e45 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageStream.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageStream.java @@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.storage; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Flow; -import org.whispersystems.textsecuregcm.entities.MessageProtos; /// A message stream publishes an ordered stream of Signal messages from a destination device's queue and provides a /// mechanism for consumers to acknowledge receipt of delivered messages. diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java index d59cf7949..1bdf59e00 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -2,15 +2,23 @@ package org.whispersystems.textsecuregcm.storage.foundationdb; import com.apple.foundationdb.Database; import com.apple.foundationdb.MutationType; +import com.apple.foundationdb.Range; import com.apple.foundationdb.Transaction; import com.apple.foundationdb.subspace.Subspace; import com.apple.foundationdb.tuple.Tuple; import com.apple.foundationdb.tuple.Versionstamp; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Multimap; +import com.google.common.collect.MultimapBuilder; +import com.google.common.collect.Multimaps; import com.google.common.hash.Hashing; import io.dropwizard.util.DataSize; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; import java.time.Clock; import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -22,9 +30,6 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Timer; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; @@ -32,6 +37,8 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessageStream; import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Util; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; /// An implementation of a message store backed by FoundationDB. /// @@ -47,6 +54,7 @@ import org.whispersystems.textsecuregcm.util.Util; public class FoundationDbMessageStore { private final Database[][] databasesByEpoch; + private final Map versionstampClocks; private final int activeEpoch; private final VersionstampUUIDCipher versionstampUUIDCipher; private final Clock clock; @@ -101,6 +109,10 @@ public class FoundationDbMessageStore { this.databasesByEpoch = databasesByEpochArray; this.activeEpoch = activeEpoch; this.versionstampUUIDCipher = versionstampUUIDCipher; + this.versionstampClocks = databasesByEpoch.values().stream() + .flatMap(List::stream) + .distinct() + .collect(Collectors.toMap(Function.identity(), db -> new VersionstampClock(db, clock))); this.clock = clock; } @@ -397,6 +409,45 @@ public class FoundationDbMessageStore { doAfterCleanup); } + /// Record the versionstamp for the current time in each database's versionstamp clock. + public void recordVersionstamps() { + CompletableFuture + .allOf( + versionstampClocks.values().stream() + .map(VersionstampClock::recordVersionstampAndTime) + .toArray(CompletableFuture[]::new)) + .join(); + } + + @VisibleForTesting + CompletableFuture deleteMessagesBefore(final Map> accountDeviceIdentifiers, final Instant cutoffTime) { + final List liveEpochs = IntStream.range(0, MAX_EPOCHS).filter(e -> databasesByEpoch[e] != null).boxed().toList(); + final Multimap queueSpacesByDatabase = accountDeviceIdentifiers.entrySet() + .stream() + .flatMap(entry -> liveEpochs.stream().map(e -> getShardForAci(entry.getKey(), e)).distinct().map(db -> Tuples.of(db, entry))) + .collect( + Multimaps.flatteningToMultimap( + Tuple2::getT1, + dbAndAciAndDevices -> { + final AciServiceIdentifier aci = dbAndAciAndDevices.getT2().getKey(); + final List devices = dbAndAciAndDevices.getT2().getValue(); + return devices.stream().map(deviceId -> getDeviceQueueSubspace(aci, deviceId)); + }, + MultimapBuilder.hashKeys().arrayListValues()::build)); + + return CompletableFuture.allOf( + versionstampClocks.entrySet().stream() + .flatMap(dbAndClock -> dbAndClock.getValue().getVersionstamp(cutoffTime).map(v -> Tuples.of(dbAndClock.getKey(), v)).stream()) + .map(dbAndVersionstamp -> { + final Database db = dbAndVersionstamp.getT1(); + final Versionstamp versionstamp = dbAndVersionstamp.getT2(); + return db.runAsync(txn -> { + queueSpacesByDatabase.get(db).forEach(s -> txn.clear(new Range(s.pack(Tuple.from()), s.pack(Tuple.from(versionstamp))))); + return CompletableFuture.completedFuture(null); + }); + }).toArray(CompletableFuture[]::new)); + } + static Versionstamp getVersionstamp(final byte[] messageKey) { return Tuple.fromBytes(messageKey).getVersionstamp(4); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClock.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClock.java index ee31366e3..e66e92ab9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClock.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClock.java @@ -16,6 +16,7 @@ import java.time.Clock; import java.time.Instant; import java.util.Iterator; import java.util.Optional; +import java.util.concurrent.CompletableFuture; /// A "versionstamp clock" records versionstamp/time pairs in a FoundationDB database. The purpose of the "clock" is to /// allow callers to get the versionstamp of a FoundationDB database at a specific real-world time to facilitate @@ -44,19 +45,18 @@ public class VersionstampClock { /// Make a recording in the database of the current time and associated versionstamp. /// - /// @return the versionstamp for the newly-recorded entry - public Versionstamp recordVersionstampAndTime() { + /// @return a future for the versionstamp for the newly-recorded entry + public CompletableFuture recordVersionstampAndTime() { final Instant currentTime = clock.instant(); - return database.run(transaction -> { + return database.runAsync(transaction -> { transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getTimestampKey(currentTime), Tuple.from(Versionstamp.incomplete()).packWithVersionstamp()); - return transaction.getVersionstamp(); + return CompletableFuture.completedFuture(transaction.getVersionstamp()); }) - .thenApply(Versionstamp::complete) - .join(); + .thenApply(ff -> Versionstamp.complete(ff.join())); } /// Returns the highest versionstamp recorded at or before the given instant. diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java index 1d7303ffe..b6286842a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -19,10 +19,8 @@ import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.util.DataSize; import java.io.UncheckedIOException; import java.security.SecureRandom; -import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.time.ZoneId; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -62,6 +60,7 @@ import org.whispersystems.textsecuregcm.storage.FoundationDbClusterExtension; import org.whispersystems.textsecuregcm.storage.MessageStream; import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; import org.whispersystems.textsecuregcm.util.Conversions; +import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.Util; @@ -77,7 +76,7 @@ class FoundationDbMessageStoreTest { private VersionstampUUIDCipher versionstampUUIDCipher; private FoundationDbMessageStore foundationDbMessageStore; - private static final Clock CLOCK = Clock.fixed(Instant.ofEpochSecond(500), ZoneId.of("UTC")); + private static final TestClock CLOCK = TestClock.pinned(Instant.ofEpochSecond(500)); private static final int DEFAULT_EPOCH = 0; private static final int FUTURE_EPOCH = 2; @@ -982,6 +981,57 @@ class FoundationDbMessageStoreTest { assertEquals(shardId, FoundationDbMessageStore.getShardId(versionstamp)); } + @Test + void deleteExpiredMessages() { + final Instant oldTime = CLOCK.instant(); + final Instant threshold = oldTime.plus(Duration.ofDays(1)); + final Instant newTime = oldTime.plus(Duration.ofDays(2)); + final Instant testTime = oldTime.plus(Duration.ofDays(3)); + + final List acis = IntStream.range(0, 1024).mapToObj(_ -> new AciServiceIdentifier(UUID.randomUUID())).toList(); + final Device device = new Device(); + device.setId(Device.PRIMARY_ID); + + // insert some old messages + CompletableFuture.allOf( + IntStream.range(0, 10) + .mapToObj(_ -> foundationDbMessageStore.insert( + acis.stream() + .collect(Collectors.toMap(Function.identity(), _ -> Map.of(Device.PRIMARY_ID, generateRandomMessage(false)))))) + .toArray(CompletableFuture[]::new)) + .join(); + + // update the versionstamp at the cutoff threshold + CLOCK.pin(threshold); + foundationDbMessageStore.recordVersionstamps(); + + // insert some new messages + CLOCK.pin(newTime); + CompletableFuture.allOf( + IntStream.range(0, 10) + .mapToObj(_ -> foundationDbMessageStore.insert( + acis.stream() + .collect(Collectors.toMap(Function.identity(), _ -> Map.of(Device.PRIMARY_ID, generateRandomMessage(false)))))) + .toArray(CompletableFuture[]::new)) + .join(); + + // advance to a future date, and clear messages before the threshold + CLOCK.pin(testTime); + foundationDbMessageStore.deleteMessagesBefore( + acis.stream().collect(Collectors.toMap(Function.identity(), _ -> List.of(Device.PRIMARY_ID))), + threshold).join(); + + // make sure we have new but not old messages + for (AciServiceIdentifier aci : acis) { + final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); + final List messages = JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()).takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty).filter(entry -> entry instanceof MessageStreamEntry.Envelope).cast(MessageStreamEntry.Envelope.class).map(MessageStreamEntry.Envelope::message).collectList().block(); + assertEquals(10, messages.size()); + for (MessageProtos.Envelope m : messages) { + assertEquals(newTime.toEpochMilli(), m.getServerTimestamp()); + } + } + } + static MessageProtos.Envelope generateRandomMessage(final boolean ephemeral, final int contentSize) { return generateRandomMessage(ephemeral, contentSize, CLOCK.millis()); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClockTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClockTest.java index 20338d3d2..b30858145 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClockTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/VersionstampClockTest.java @@ -47,10 +47,10 @@ class VersionstampClockTest { final Instant tomorrow = start.plus(Duration.ofDays(1)); clock.pin(firstInsert); - final Versionstamp firstStamp = versionstampClock.recordVersionstampAndTime(); + final Versionstamp firstStamp = versionstampClock.recordVersionstampAndTime().join(); clock.pin(secondInsert); - final Versionstamp secondStamp = versionstampClock.recordVersionstampAndTime(); + final Versionstamp secondStamp = versionstampClock.recordVersionstampAndTime().join(); assertThat(versionstampClock.getVersionstamp(start)).isEmpty(); assertThat(versionstampClock.getVersionstamp(firstInsert)).isPresent().hasValue(firstStamp); @@ -66,13 +66,13 @@ class VersionstampClockTest { final Instant old = start.minus(Duration.ofDays(21)); clock.pin(veryOld); - final Versionstamp veryOldStamp = versionstampClock.recordVersionstampAndTime(); + final Versionstamp veryOldStamp = versionstampClock.recordVersionstampAndTime().join(); clock.pin(old); - final Versionstamp oldStamp = versionstampClock.recordVersionstampAndTime(); + final Versionstamp oldStamp = versionstampClock.recordVersionstampAndTime().join(); clock.pin(start); - final Versionstamp nowStamp = versionstampClock.recordVersionstampAndTime(); + final Versionstamp nowStamp = versionstampClock.recordVersionstampAndTime().join(); assertThat(versionstampClock.getVersionstamp(veryOld)).isPresent().hasValue(veryOldStamp);