Add byte-based rate-limit to attachments

This commit is contained in:
Ravi Khadiwala 2026-03-29 23:53:10 -05:00 committed by ravi-signal
parent f9d3cd8d82
commit 0ee06d83b7
7 changed files with 179 additions and 19 deletions

View File

@ -44,7 +44,8 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
public class AttachmentControllerV4 {
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final RateLimiter rateLimiter;
private final RateLimiter countRateLimiter;
private final RateLimiter bytesRateLimiter;
private final long maxUploadLength;
private final Map<Integer, AttachmentGenerator> attachmentGenerators;
@ -58,7 +59,8 @@ public class AttachmentControllerV4 {
final TusAttachmentGenerator tusAttachmentGenerator,
final ExperimentEnrollmentManager experimentEnrollmentManager,
final long maxUploadLength) {
this.rateLimiter = rateLimiters.getAttachmentLimiter();
this.countRateLimiter = rateLimiters.getAttachmentLimiter();
this.bytesRateLimiter = rateLimiters.getAttachmentBytesLimiter();
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.maxUploadLength = maxUploadLength;
this.secureRandom = new SecureRandom();
@ -91,11 +93,23 @@ public class AttachmentControllerV4 {
@Parameter(description = "The size of the attachment to upload in bytes")
@QueryParam("uploadLength") final @Valid Optional<@Positive Long> maybeUploadLength)
throws RateLimitExceededException {
final long uploadLength = maybeUploadLength.orElse(maxUploadLength);
if (uploadLength > maxUploadLength) {
throw new ClientErrorException("exceeded maximum uploadLength", Response.Status.REQUEST_ENTITY_TOO_LARGE);
}
rateLimiter.validate(auth.accountIdentifier());
countRateLimiter.validate(auth.accountIdentifier());
if (maybeUploadLength.isPresent()) {
// Ideally we'd check these two rate limits transactionally and only update them if both permits were acquired.
// However, just undoing the first modification if the second one fails is close enough for our purposes
try {
bytesRateLimiter.validate(auth.accountIdentifier(), maybeUploadLength.get());
} catch (RateLimitExceededException e) {
countRateLimiter.restorePermits(auth.accountIdentifier(), 1);
throw e;
}
}
final String key = AttachmentUtil.generateAttachmentKey(secureRandom);
final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.accountIdentifier(), AttachmentUtil.CDN3_EXPERIMENT_NAME);

View File

@ -26,7 +26,8 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
public class AttachmentsGrpcService extends SimpleAttachmentsGrpc.AttachmentsImplBase {
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final RateLimiter rateLimiter;
private final RateLimiter countRateLimiter;
private final RateLimiter bytesRateLimiter;
private final long maxUploadLength;
private final Map<Integer, AttachmentGenerator> attachmentGenerators;
private final SecureRandom secureRandom;
@ -38,7 +39,8 @@ public class AttachmentsGrpcService extends SimpleAttachmentsGrpc.AttachmentsImp
final TusAttachmentGenerator tusAttachmentGenerator,
final long maxUploadLength) {
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.rateLimiter = rateLimiters.getAttachmentLimiter();
this.countRateLimiter = rateLimiters.getAttachmentLimiter();
this.bytesRateLimiter = rateLimiters.getAttachmentBytesLimiter();
this.maxUploadLength = maxUploadLength;
this.secureRandom = new SecureRandom();
this.attachmentGenerators = Map.of(
@ -54,7 +56,16 @@ public class AttachmentsGrpcService extends SimpleAttachmentsGrpc.AttachmentsImp
.build();
}
final AuthenticatedDevice auth = AuthenticationUtil.requireAuthenticatedDevice();
rateLimiter.validate(auth.accountIdentifier());
countRateLimiter.validate(auth.accountIdentifier());
try {
// Ideally we'd check these two rate limits transactionally and only update them if both permits were acquired.
// However, just undoing the first modification if the second one fails is close enough for our purposes
bytesRateLimiter.validate(auth.accountIdentifier(), request.getUploadLength());
} catch (RateLimitExceededException e) {
countRateLimiter.restorePermits(auth.accountIdentifier(), 1);
throw e;
}
final String key = AttachmentUtil.generateAttachmentKey(secureRandom);
final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.accountIdentifier(),

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.limits;
import java.util.UUID;
import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import reactor.core.publisher.Mono;
public interface RateLimiter {
@ -78,4 +77,27 @@ public interface RateLimiter {
default CompletionStage<Void> clearAsync(final UUID accountUuid) {
return clearAsync(accountUuid.toString());
}
/// Restore `permitsToRestore` permits to the pool. If the RateLimiter has a maximum bucket size, it is respected even
/// if `permitsToRestore` would exceed the maximum bucket size.
///
/// @param key The key to restore permits for.
/// @param permitsToRestore The number of permits to restore
/// @implNote The default implementation of this method assumes permits can be restored by calling `validate` with a
/// negative permit count. If the `validate` implementation does not support this, the implementor must provide a
/// custom implementation of [this#restorePermits].
default void restorePermits(final String key, int permitsToRestore) {
if (permitsToRestore < 0) {
throw new IllegalArgumentException("permits to restore must be non-negative");
}
try {
validate(key, -1 * permitsToRestore);
} catch (RateLimitExceededException e) {
throw new IllegalStateException("Out of permits when trying to restore permits", e);
}
}
default void restorePermits(final UUID accountUuid, int permitsToRestore) {
restorePermits(accountUuid.toString(), permitsToRestore);
}
}

View File

@ -8,6 +8,7 @@ import com.google.common.annotations.VisibleForTesting;
import java.time.Clock;
import java.time.Duration;
import java.util.concurrent.ScheduledExecutorService;
import io.dropwizard.util.DataSize;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
@ -19,6 +20,7 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
BACKUP_AUTH_CHECK("backupAuthCheck", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
PIN("pin", new RateLimiterConfig(10, Duration.ofDays(1), false)),
ATTACHMENT("attachmentCreate", new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
ATTACHMENT_BYTES("attachmentCreateBytes", new RateLimiterConfig(DataSize.gigabytes(1).toBytes(), Duration.ofNanos(1000), true)),
BACKUP_ATTACHMENT("backupAttachmentCreate", new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
PRE_KEYS("prekeys", new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
MESSAGES("messages", new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
@ -115,6 +117,10 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
return forDescriptor(For.ATTACHMENT);
}
public RateLimiter getAttachmentBytesLimiter() {
return forDescriptor(For.ATTACHMENT_BYTES);
}
public RateLimiter getPinLimiter() {
return forDescriptor(For.PIN);
}

View File

@ -7,12 +7,16 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
@ -27,14 +31,15 @@ import java.security.spec.InvalidKeySpecException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import jakarta.ws.rs.core.Response;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.attachments.AttachmentUtil;
import org.whispersystems.textsecuregcm.attachments.GcsAttachmentGenerator;
import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator;
import org.whispersystems.textsecuregcm.attachments.TusConfiguration;
@ -44,7 +49,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.attachments.AttachmentUtil;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -53,11 +58,13 @@ import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ExtendWith(DropwizardExtensionsSupport.class)
class AttachmentControllerV4Test {
private static final RateLimiter RATE_LIMITER = mock(RateLimiter.class);
private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rateLimiters ->
when(rateLimiters.getAttachmentLimiter()).thenReturn(RATE_LIMITER));
private static final RateLimiter COUNT_RATE_LIMITER = mock(RateLimiter.class);
private static final RateLimiter BYTE_RATE_LIMITER = mock(RateLimiter.class);
private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rateLimiters -> {
when(rateLimiters.getAttachmentLimiter()).thenReturn(COUNT_RATE_LIMITER);
when(rateLimiters.getAttachmentBytesLimiter()).thenReturn(BYTE_RATE_LIMITER);
});
private static final String CDN3_ENABLED_CREDS = AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD);
private static final String CDN3_DISABLED_CREDS = AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO);
@ -98,6 +105,7 @@ class AttachmentControllerV4Test {
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new RateLimitExceededExceptionMapper())
.addProvider(new AttachmentControllerV4(RATE_LIMITERS,
gcsAttachmentGenerator,
new TusAttachmentGenerator(new TusConfiguration(new SecretBytes(TUS_SECRET), TUS_URL)),
@ -108,6 +116,11 @@ class AttachmentControllerV4Test {
}
}
@AfterEach
public void tearDown() {
reset(COUNT_RATE_LIMITER, BYTE_RATE_LIMITER);
}
@ParameterizedTest
@ValueSource(longs = {-1, 0})
void testInvalidUploadLength(final long uploadLength) {
@ -131,6 +144,44 @@ class AttachmentControllerV4Test {
.isEqualTo(Response.Status.REQUEST_ENTITY_TOO_LARGE.getStatusCode());
}
@Test
void missingUploadLengthDoesNotRateLimit() throws RateLimitExceededException {
doThrow(RateLimitExceededException.class).when(BYTE_RATE_LIMITER).validate(AuthHelper.VALID_UUID);
assertThatNoException().isThrownBy(() -> resources.getJerseyTest()
.target("/v4/attachments/form/upload")
.request()
.header("Authorization", CDN3_ENABLED_CREDS)
.get().getStatus());
}
@Test
void countRateLimitExceeded() throws RateLimitExceededException {
doThrow(RateLimitExceededException.class).when(COUNT_RATE_LIMITER).validate(AuthHelper.VALID_UUID);
assertThat(resources.getJerseyTest()
.target("/v4/attachments/form/upload")
.request()
.header("Authorization", CDN3_ENABLED_CREDS)
.get().getStatus())
.isEqualTo(Response.Status.TOO_MANY_REQUESTS.getStatusCode());
}
@Test
void rollbackRateLimit() throws RateLimitExceededException {
doThrow(RateLimitExceededException.class).when(BYTE_RATE_LIMITER)
.validate(AuthHelper.VALID_UUID, MAX_UPLOAD_LENGTH);
assertThat(resources.getJerseyTest()
.target("/v4/attachments/form/upload")
.queryParam("uploadLength", MAX_UPLOAD_LENGTH)
.request()
.header("Authorization", CDN3_ENABLED_CREDS)
.get().getStatus())
.isEqualTo(Response.Status.TOO_MANY_REQUESTS.getStatusCode());
verify(COUNT_RATE_LIMITER).validate(AuthHelper.VALID_UUID);
verify(COUNT_RATE_LIMITER).restorePermits(AuthHelper.VALID_UUID, 1);
}
@Test
void testV4TusForm() {
AttachmentDescriptorV3 descriptor = resources.getJerseyTest()

View File

@ -7,7 +7,9 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
@ -60,7 +62,9 @@ class AttachmentsGrpcServiceTest extends
@Mock
private ExperimentEnrollmentManager experimentEnrollmentManager;
@Mock
private RateLimiter rateLimiter;
private RateLimiter countRateLimiter;
@Mock
private RateLimiter byteRateLimiter;
@Override
protected AttachmentsGrpcService createServiceBeforeEachTest() {
@ -79,8 +83,10 @@ class AttachmentsGrpcServiceTest extends
return new AttachmentsGrpcService(
experimentEnrollmentManager,
MockUtils.buildMock(RateLimiters.class, rateLimiters ->
when(rateLimiters.getAttachmentLimiter()).thenReturn(rateLimiter)),
MockUtils.buildMock(RateLimiters.class, rateLimiters -> {
when(rateLimiters.getAttachmentLimiter()).thenReturn(countRateLimiter);
when(rateLimiters.getAttachmentBytesLimiter()).thenReturn(byteRateLimiter);
}),
gcsAttachmentGenerator,
tusAttachmentGenerator,
MAX_UPLOAD_LENGTH);
@ -103,6 +109,27 @@ class AttachmentsGrpcServiceTest extends
.hasExceedsMaxUploadLength()).isTrue();
}
@Test
void countRateLimitExceeded() throws RateLimitExceededException {
final Duration retryAfter = Duration.ofSeconds(17);
doThrow(new RateLimitExceededException(retryAfter)).when(countRateLimiter).validate(AUTHENTICATED_ACI);
GrpcTestUtils.assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub()
.getUploadForm(GetUploadFormRequest.newBuilder().setUploadLength(123).build()));
}
@Test
void rollbackRateLimit() throws RateLimitExceededException {
final Duration retryAfter = Duration.ofSeconds(17);
doThrow(new RateLimitExceededException(retryAfter)).when(byteRateLimiter)
.validate(AUTHENTICATED_ACI, 123);
GrpcTestUtils.assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub()
.getUploadForm(GetUploadFormRequest.newBuilder().setUploadLength(123).build()));
verify(countRateLimiter).validate(AUTHENTICATED_ACI);
verify(countRateLimiter).restorePermits(AUTHENTICATED_ACI, 1);
}
@Test
void getUploadFormCdn3() {
when(experimentEnrollmentManager.isEnrolled(AUTHENTICATED_ACI, AttachmentUtil.CDN3_EXPERIMENT_NAME))
@ -169,10 +196,15 @@ class AttachmentsGrpcServiceTest extends
assertThat(credentialParts[4]).isEqualTo("goog4_request");
}
@Test
void getUploadFormRateLimited() throws RateLimitExceededException {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void getUploadFormRateLimited(boolean useByteRateLimiter) throws RateLimitExceededException {
final Duration retryAfter = Duration.ofMinutes(5);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimiter).validate(any(UUID.class));
if (useByteRateLimiter) {
doThrow(new RateLimitExceededException(retryAfter)).when(byteRateLimiter).validate(any(UUID.class), anyLong());
} else {
doThrow(new RateLimitExceededException(retryAfter)).when(countRateLimiter).validate(any(UUID.class));
}
assertRateLimitExceeded(retryAfter, () ->
authenticatedServiceStub().getUploadForm(GetUploadFormRequest.newBuilder().setUploadLength(1).build()));

View File

@ -76,6 +76,30 @@ class LeakyBucketRateLimiterTest {
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
}
@ParameterizedTest
@ValueSource(ints = {1, 2, 3, 4, 100})
void restorePermits(int permitsToAdd) throws RateLimitExceededException {
final int bucketSize = 3;
final LeakyBucketRateLimiter rateLimiter = new LeakyBucketRateLimiter(
"test",
() -> new RateLimiterConfig(bucketSize, Duration.ofHours(1), false),
validateRateLimitScript,
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
retryExecutor,
CLOCK);
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
for (int i = 0; i < bucketSize; i++) {
rateLimiter.validate(key);
}
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
rateLimiter.restorePermits(key, permitsToAdd);
for (int i = 0; i < Math.min(permitsToAdd, bucketSize); i++) {
assertDoesNotThrow(() -> rateLimiter.validate(key));
}
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void validateAsync(final boolean failOpen) {