add FoundationDbMessageStore api to clear messages before a given time

This commit is contained in:
Jonathan Klabunde Tomer 2026-06-18 17:12:43 -07:00 committed by Chris Eager
parent 78b3147491
commit f4e16676c9
5 changed files with 118 additions and 18 deletions

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.storage;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
/// A message stream publishes an ordered stream of Signal messages from a destination device's queue and provides a
/// mechanism for consumers to acknowledge receipt of delivered messages.

View File

@ -2,15 +2,23 @@ package org.whispersystems.textsecuregcm.storage.foundationdb;
import com.apple.foundationdb.Database;
import com.apple.foundationdb.MutationType;
import com.apple.foundationdb.Range;
import com.apple.foundationdb.Transaction;
import com.apple.foundationdb.subspace.Subspace;
import com.apple.foundationdb.tuple.Tuple;
import com.apple.foundationdb.tuple.Versionstamp;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Multimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Multimaps;
import com.google.common.hash.Hashing;
import io.dropwizard.util.DataSize;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@ -22,9 +30,6 @@ import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
@ -32,6 +37,8 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
/// An implementation of a message store backed by FoundationDB.
///
@ -47,6 +54,7 @@ import org.whispersystems.textsecuregcm.util.Util;
public class FoundationDbMessageStore {
private final Database[][] databasesByEpoch;
private final Map<Database, VersionstampClock> versionstampClocks;
private final int activeEpoch;
private final VersionstampUUIDCipher versionstampUUIDCipher;
private final Clock clock;
@ -101,6 +109,10 @@ public class FoundationDbMessageStore {
this.databasesByEpoch = databasesByEpochArray;
this.activeEpoch = activeEpoch;
this.versionstampUUIDCipher = versionstampUUIDCipher;
this.versionstampClocks = databasesByEpoch.values().stream()
.flatMap(List::stream)
.distinct()
.collect(Collectors.toMap(Function.identity(), db -> new VersionstampClock(db, clock)));
this.clock = clock;
}
@ -397,6 +409,45 @@ public class FoundationDbMessageStore {
doAfterCleanup);
}
/// Record the versionstamp for the current time in each database's versionstamp clock.
public void recordVersionstamps() {
CompletableFuture
.allOf(
versionstampClocks.values().stream()
.map(VersionstampClock::recordVersionstampAndTime)
.toArray(CompletableFuture[]::new))
.join();
}
@VisibleForTesting
CompletableFuture<Void> deleteMessagesBefore(final Map<AciServiceIdentifier, List<Byte>> accountDeviceIdentifiers, final Instant cutoffTime) {
final List<Integer> liveEpochs = IntStream.range(0, MAX_EPOCHS).filter(e -> databasesByEpoch[e] != null).boxed().toList();
final Multimap<Database, Subspace> queueSpacesByDatabase = accountDeviceIdentifiers.entrySet()
.stream()
.flatMap(entry -> liveEpochs.stream().map(e -> getShardForAci(entry.getKey(), e)).distinct().map(db -> Tuples.of(db, entry)))
.collect(
Multimaps.flatteningToMultimap(
Tuple2::getT1,
dbAndAciAndDevices -> {
final AciServiceIdentifier aci = dbAndAciAndDevices.getT2().getKey();
final List<Byte> devices = dbAndAciAndDevices.getT2().getValue();
return devices.stream().map(deviceId -> getDeviceQueueSubspace(aci, deviceId));
},
MultimapBuilder.hashKeys().arrayListValues()::build));
return CompletableFuture.allOf(
versionstampClocks.entrySet().stream()
.flatMap(dbAndClock -> dbAndClock.getValue().getVersionstamp(cutoffTime).map(v -> Tuples.of(dbAndClock.getKey(), v)).stream())
.map(dbAndVersionstamp -> {
final Database db = dbAndVersionstamp.getT1();
final Versionstamp versionstamp = dbAndVersionstamp.getT2();
return db.runAsync(txn -> {
queueSpacesByDatabase.get(db).forEach(s -> txn.clear(new Range(s.pack(Tuple.from()), s.pack(Tuple.from(versionstamp)))));
return CompletableFuture.<Void>completedFuture(null);
});
}).toArray(CompletableFuture[]::new));
}
static Versionstamp getVersionstamp(final byte[] messageKey) {
return Tuple.fromBytes(messageKey).getVersionstamp(4);
}

View File

@ -16,6 +16,7 @@ import java.time.Clock;
import java.time.Instant;
import java.util.Iterator;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
/// A "versionstamp clock" records versionstamp/time pairs in a FoundationDB database. The purpose of the "clock" is to
/// allow callers to get the versionstamp of a FoundationDB database at a specific real-world time to facilitate
@ -44,19 +45,18 @@ public class VersionstampClock {
/// Make a recording in the database of the current time and associated versionstamp.
///
/// @return the versionstamp for the newly-recorded entry
public Versionstamp recordVersionstampAndTime() {
/// @return a future for the versionstamp for the newly-recorded entry
public CompletableFuture<Versionstamp> recordVersionstampAndTime() {
final Instant currentTime = clock.instant();
return database.run(transaction -> {
return database.runAsync(transaction -> {
transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE,
getTimestampKey(currentTime),
Tuple.from(Versionstamp.incomplete()).packWithVersionstamp());
return transaction.getVersionstamp();
return CompletableFuture.completedFuture(transaction.getVersionstamp());
})
.thenApply(Versionstamp::complete)
.join();
.thenApply(ff -> Versionstamp.complete(ff.join()));
}
/// Returns the highest versionstamp recorded at or before the given instant.

View File

@ -19,10 +19,8 @@ import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.util.DataSize;
import java.io.UncheckedIOException;
import java.security.SecureRandom;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -62,6 +60,7 @@ import org.whispersystems.textsecuregcm.storage.FoundationDbClusterExtension;
import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.Util;
@ -77,7 +76,7 @@ class FoundationDbMessageStoreTest {
private VersionstampUUIDCipher versionstampUUIDCipher;
private FoundationDbMessageStore foundationDbMessageStore;
private static final Clock CLOCK = Clock.fixed(Instant.ofEpochSecond(500), ZoneId.of("UTC"));
private static final TestClock CLOCK = TestClock.pinned(Instant.ofEpochSecond(500));
private static final int DEFAULT_EPOCH = 0;
private static final int FUTURE_EPOCH = 2;
@ -982,6 +981,57 @@ class FoundationDbMessageStoreTest {
assertEquals(shardId, FoundationDbMessageStore.getShardId(versionstamp));
}
@Test
void deleteExpiredMessages() {
final Instant oldTime = CLOCK.instant();
final Instant threshold = oldTime.plus(Duration.ofDays(1));
final Instant newTime = oldTime.plus(Duration.ofDays(2));
final Instant testTime = oldTime.plus(Duration.ofDays(3));
final List<AciServiceIdentifier> acis = IntStream.range(0, 1024).mapToObj(_ -> new AciServiceIdentifier(UUID.randomUUID())).toList();
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
// insert some old messages
CompletableFuture.allOf(
IntStream.range(0, 10)
.mapToObj(_ -> foundationDbMessageStore.insert(
acis.stream()
.collect(Collectors.toMap(Function.identity(), _ -> Map.of(Device.PRIMARY_ID, generateRandomMessage(false))))))
.toArray(CompletableFuture[]::new))
.join();
// update the versionstamp at the cutoff threshold
CLOCK.pin(threshold);
foundationDbMessageStore.recordVersionstamps();
// insert some new messages
CLOCK.pin(newTime);
CompletableFuture.allOf(
IntStream.range(0, 10)
.mapToObj(_ -> foundationDbMessageStore.insert(
acis.stream()
.collect(Collectors.toMap(Function.identity(), _ -> Map.of(Device.PRIMARY_ID, generateRandomMessage(false))))))
.toArray(CompletableFuture[]::new))
.join();
// advance to a future date, and clear messages before the threshold
CLOCK.pin(testTime);
foundationDbMessageStore.deleteMessagesBefore(
acis.stream().collect(Collectors.toMap(Function.identity(), _ -> List.of(Device.PRIMARY_ID))),
threshold).join();
// make sure we have new but not old messages
for (AciServiceIdentifier aci : acis) {
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device);
final List<MessageProtos.Envelope> messages = JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()).takeUntil(entry -> entry instanceof MessageStreamEntry.QueueEmpty).filter(entry -> entry instanceof MessageStreamEntry.Envelope).cast(MessageStreamEntry.Envelope.class).map(MessageStreamEntry.Envelope::message).collectList().block();
assertEquals(10, messages.size());
for (MessageProtos.Envelope m : messages) {
assertEquals(newTime.toEpochMilli(), m.getServerTimestamp());
}
}
}
static MessageProtos.Envelope generateRandomMessage(final boolean ephemeral, final int contentSize) {
return generateRandomMessage(ephemeral, contentSize, CLOCK.millis());
}

View File

@ -47,10 +47,10 @@ class VersionstampClockTest {
final Instant tomorrow = start.plus(Duration.ofDays(1));
clock.pin(firstInsert);
final Versionstamp firstStamp = versionstampClock.recordVersionstampAndTime();
final Versionstamp firstStamp = versionstampClock.recordVersionstampAndTime().join();
clock.pin(secondInsert);
final Versionstamp secondStamp = versionstampClock.recordVersionstampAndTime();
final Versionstamp secondStamp = versionstampClock.recordVersionstampAndTime().join();
assertThat(versionstampClock.getVersionstamp(start)).isEmpty();
assertThat(versionstampClock.getVersionstamp(firstInsert)).isPresent().hasValue(firstStamp);
@ -66,13 +66,13 @@ class VersionstampClockTest {
final Instant old = start.minus(Duration.ofDays(21));
clock.pin(veryOld);
final Versionstamp veryOldStamp = versionstampClock.recordVersionstampAndTime();
final Versionstamp veryOldStamp = versionstampClock.recordVersionstampAndTime().join();
clock.pin(old);
final Versionstamp oldStamp = versionstampClock.recordVersionstampAndTime();
final Versionstamp oldStamp = versionstampClock.recordVersionstampAndTime().join();
clock.pin(start);
final Versionstamp nowStamp = versionstampClock.recordVersionstampAndTime();
final Versionstamp nowStamp = versionstampClock.recordVersionstampAndTime().join();
assertThat(versionstampClock.getVersionstamp(veryOld)).isPresent().hasValue(veryOldStamp);