Add byte-based rate-limit to attachments
This commit is contained in:
parent
f9d3cd8d82
commit
0ee06d83b7
@ -44,7 +44,8 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
public class AttachmentControllerV4 {
|
||||
|
||||
private final ExperimentEnrollmentManager experimentEnrollmentManager;
|
||||
private final RateLimiter rateLimiter;
|
||||
private final RateLimiter countRateLimiter;
|
||||
private final RateLimiter bytesRateLimiter;
|
||||
private final long maxUploadLength;
|
||||
|
||||
private final Map<Integer, AttachmentGenerator> attachmentGenerators;
|
||||
@ -58,7 +59,8 @@ public class AttachmentControllerV4 {
|
||||
final TusAttachmentGenerator tusAttachmentGenerator,
|
||||
final ExperimentEnrollmentManager experimentEnrollmentManager,
|
||||
final long maxUploadLength) {
|
||||
this.rateLimiter = rateLimiters.getAttachmentLimiter();
|
||||
this.countRateLimiter = rateLimiters.getAttachmentLimiter();
|
||||
this.bytesRateLimiter = rateLimiters.getAttachmentBytesLimiter();
|
||||
this.experimentEnrollmentManager = experimentEnrollmentManager;
|
||||
this.maxUploadLength = maxUploadLength;
|
||||
this.secureRandom = new SecureRandom();
|
||||
@ -91,11 +93,23 @@ public class AttachmentControllerV4 {
|
||||
@Parameter(description = "The size of the attachment to upload in bytes")
|
||||
@QueryParam("uploadLength") final @Valid Optional<@Positive Long> maybeUploadLength)
|
||||
throws RateLimitExceededException {
|
||||
|
||||
final long uploadLength = maybeUploadLength.orElse(maxUploadLength);
|
||||
if (uploadLength > maxUploadLength) {
|
||||
throw new ClientErrorException("exceeded maximum uploadLength", Response.Status.REQUEST_ENTITY_TOO_LARGE);
|
||||
}
|
||||
rateLimiter.validate(auth.accountIdentifier());
|
||||
|
||||
countRateLimiter.validate(auth.accountIdentifier());
|
||||
if (maybeUploadLength.isPresent()) {
|
||||
// Ideally we'd check these two rate limits transactionally and only update them if both permits were acquired.
|
||||
// However, just undoing the first modification if the second one fails is close enough for our purposes
|
||||
try {
|
||||
bytesRateLimiter.validate(auth.accountIdentifier(), maybeUploadLength.get());
|
||||
} catch (RateLimitExceededException e) {
|
||||
countRateLimiter.restorePermits(auth.accountIdentifier(), 1);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
final String key = AttachmentUtil.generateAttachmentKey(secureRandom);
|
||||
final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.accountIdentifier(), AttachmentUtil.CDN3_EXPERIMENT_NAME);
|
||||
|
||||
@ -26,7 +26,8 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
public class AttachmentsGrpcService extends SimpleAttachmentsGrpc.AttachmentsImplBase {
|
||||
|
||||
private final ExperimentEnrollmentManager experimentEnrollmentManager;
|
||||
private final RateLimiter rateLimiter;
|
||||
private final RateLimiter countRateLimiter;
|
||||
private final RateLimiter bytesRateLimiter;
|
||||
private final long maxUploadLength;
|
||||
private final Map<Integer, AttachmentGenerator> attachmentGenerators;
|
||||
private final SecureRandom secureRandom;
|
||||
@ -38,7 +39,8 @@ public class AttachmentsGrpcService extends SimpleAttachmentsGrpc.AttachmentsImp
|
||||
final TusAttachmentGenerator tusAttachmentGenerator,
|
||||
final long maxUploadLength) {
|
||||
this.experimentEnrollmentManager = experimentEnrollmentManager;
|
||||
this.rateLimiter = rateLimiters.getAttachmentLimiter();
|
||||
this.countRateLimiter = rateLimiters.getAttachmentLimiter();
|
||||
this.bytesRateLimiter = rateLimiters.getAttachmentBytesLimiter();
|
||||
this.maxUploadLength = maxUploadLength;
|
||||
this.secureRandom = new SecureRandom();
|
||||
this.attachmentGenerators = Map.of(
|
||||
@ -54,7 +56,16 @@ public class AttachmentsGrpcService extends SimpleAttachmentsGrpc.AttachmentsImp
|
||||
.build();
|
||||
}
|
||||
final AuthenticatedDevice auth = AuthenticationUtil.requireAuthenticatedDevice();
|
||||
rateLimiter.validate(auth.accountIdentifier());
|
||||
|
||||
countRateLimiter.validate(auth.accountIdentifier());
|
||||
try {
|
||||
// Ideally we'd check these two rate limits transactionally and only update them if both permits were acquired.
|
||||
// However, just undoing the first modification if the second one fails is close enough for our purposes
|
||||
bytesRateLimiter.validate(auth.accountIdentifier(), request.getUploadLength());
|
||||
} catch (RateLimitExceededException e) {
|
||||
countRateLimiter.restorePermits(auth.accountIdentifier(), 1);
|
||||
throw e;
|
||||
}
|
||||
|
||||
final String key = AttachmentUtil.generateAttachmentKey(secureRandom);
|
||||
final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.accountIdentifier(),
|
||||
|
||||
@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.limits;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
public interface RateLimiter {
|
||||
@ -78,4 +77,27 @@ public interface RateLimiter {
|
||||
default CompletionStage<Void> clearAsync(final UUID accountUuid) {
|
||||
return clearAsync(accountUuid.toString());
|
||||
}
|
||||
|
||||
/// Restore `permitsToRestore` permits to the pool. If the RateLimiter has a maximum bucket size, it is respected even
|
||||
/// if `permitsToRestore` would exceed the maximum bucket size.
|
||||
///
|
||||
/// @param key The key to restore permits for.
|
||||
/// @param permitsToRestore The number of permits to restore
|
||||
/// @implNote The default implementation of this method assumes permits can be restored by calling `validate` with a
|
||||
/// negative permit count. If the `validate` implementation does not support this, the implementor must provide a
|
||||
/// custom implementation of [this#restorePermits].
|
||||
default void restorePermits(final String key, int permitsToRestore) {
|
||||
if (permitsToRestore < 0) {
|
||||
throw new IllegalArgumentException("permits to restore must be non-negative");
|
||||
}
|
||||
try {
|
||||
validate(key, -1 * permitsToRestore);
|
||||
} catch (RateLimitExceededException e) {
|
||||
throw new IllegalStateException("Out of permits when trying to restore permits", e);
|
||||
}
|
||||
}
|
||||
|
||||
default void restorePermits(final UUID accountUuid, int permitsToRestore) {
|
||||
restorePermits(accountUuid.toString(), permitsToRestore);
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ import com.google.common.annotations.VisibleForTesting;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import io.dropwizard.util.DataSize;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
@ -19,6 +20,7 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
||||
BACKUP_AUTH_CHECK("backupAuthCheck", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
||||
PIN("pin", new RateLimiterConfig(10, Duration.ofDays(1), false)),
|
||||
ATTACHMENT("attachmentCreate", new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
|
||||
ATTACHMENT_BYTES("attachmentCreateBytes", new RateLimiterConfig(DataSize.gigabytes(1).toBytes(), Duration.ofNanos(1000), true)),
|
||||
BACKUP_ATTACHMENT("backupAttachmentCreate", new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
|
||||
PRE_KEYS("prekeys", new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
|
||||
MESSAGES("messages", new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
|
||||
@ -115,6 +117,10 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
||||
return forDescriptor(For.ATTACHMENT);
|
||||
}
|
||||
|
||||
public RateLimiter getAttachmentBytesLimiter() {
|
||||
return forDescriptor(For.ATTACHMENT_BYTES);
|
||||
}
|
||||
|
||||
public RateLimiter getPinLimiter() {
|
||||
return forDescriptor(For.PIN);
|
||||
}
|
||||
|
||||
@ -7,12 +7,16 @@ package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatNoException;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import io.dropwizard.auth.AuthValueFactoryProvider;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||
import jakarta.ws.rs.core.Response;
|
||||
import java.io.IOException;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URI;
|
||||
@ -27,14 +31,15 @@ import java.security.spec.InvalidKeySpecException;
|
||||
import java.util.Base64;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import jakarta.ws.rs.core.Response;
|
||||
import org.assertj.core.api.Assertions;
|
||||
import org.assertj.core.api.InstanceOfAssertFactories;
|
||||
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.whispersystems.textsecuregcm.attachments.AttachmentUtil;
|
||||
import org.whispersystems.textsecuregcm.attachments.GcsAttachmentGenerator;
|
||||
import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator;
|
||||
import org.whispersystems.textsecuregcm.attachments.TusConfiguration;
|
||||
@ -44,7 +49,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
|
||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.attachments.AttachmentUtil;
|
||||
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
@ -53,11 +58,13 @@ import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class AttachmentControllerV4Test {
|
||||
|
||||
private static final RateLimiter RATE_LIMITER = mock(RateLimiter.class);
|
||||
|
||||
private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rateLimiters ->
|
||||
when(rateLimiters.getAttachmentLimiter()).thenReturn(RATE_LIMITER));
|
||||
private static final RateLimiter COUNT_RATE_LIMITER = mock(RateLimiter.class);
|
||||
private static final RateLimiter BYTE_RATE_LIMITER = mock(RateLimiter.class);
|
||||
|
||||
private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rateLimiters -> {
|
||||
when(rateLimiters.getAttachmentLimiter()).thenReturn(COUNT_RATE_LIMITER);
|
||||
when(rateLimiters.getAttachmentBytesLimiter()).thenReturn(BYTE_RATE_LIMITER);
|
||||
});
|
||||
|
||||
private static final String CDN3_ENABLED_CREDS = AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD);
|
||||
private static final String CDN3_DISABLED_CREDS = AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO);
|
||||
@ -98,6 +105,7 @@ class AttachmentControllerV4Test {
|
||||
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
|
||||
.setMapper(SystemMapper.jsonMapper())
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addProvider(new RateLimitExceededExceptionMapper())
|
||||
.addProvider(new AttachmentControllerV4(RATE_LIMITERS,
|
||||
gcsAttachmentGenerator,
|
||||
new TusAttachmentGenerator(new TusConfiguration(new SecretBytes(TUS_SECRET), TUS_URL)),
|
||||
@ -108,6 +116,11 @@ class AttachmentControllerV4Test {
|
||||
}
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void tearDown() {
|
||||
reset(COUNT_RATE_LIMITER, BYTE_RATE_LIMITER);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(longs = {-1, 0})
|
||||
void testInvalidUploadLength(final long uploadLength) {
|
||||
@ -131,6 +144,44 @@ class AttachmentControllerV4Test {
|
||||
.isEqualTo(Response.Status.REQUEST_ENTITY_TOO_LARGE.getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
void missingUploadLengthDoesNotRateLimit() throws RateLimitExceededException {
|
||||
doThrow(RateLimitExceededException.class).when(BYTE_RATE_LIMITER).validate(AuthHelper.VALID_UUID);
|
||||
|
||||
assertThatNoException().isThrownBy(() -> resources.getJerseyTest()
|
||||
.target("/v4/attachments/form/upload")
|
||||
.request()
|
||||
.header("Authorization", CDN3_ENABLED_CREDS)
|
||||
.get().getStatus());
|
||||
}
|
||||
|
||||
@Test
|
||||
void countRateLimitExceeded() throws RateLimitExceededException {
|
||||
doThrow(RateLimitExceededException.class).when(COUNT_RATE_LIMITER).validate(AuthHelper.VALID_UUID);
|
||||
assertThat(resources.getJerseyTest()
|
||||
.target("/v4/attachments/form/upload")
|
||||
.request()
|
||||
.header("Authorization", CDN3_ENABLED_CREDS)
|
||||
.get().getStatus())
|
||||
.isEqualTo(Response.Status.TOO_MANY_REQUESTS.getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
void rollbackRateLimit() throws RateLimitExceededException {
|
||||
doThrow(RateLimitExceededException.class).when(BYTE_RATE_LIMITER)
|
||||
.validate(AuthHelper.VALID_UUID, MAX_UPLOAD_LENGTH);
|
||||
assertThat(resources.getJerseyTest()
|
||||
.target("/v4/attachments/form/upload")
|
||||
.queryParam("uploadLength", MAX_UPLOAD_LENGTH)
|
||||
.request()
|
||||
.header("Authorization", CDN3_ENABLED_CREDS)
|
||||
.get().getStatus())
|
||||
.isEqualTo(Response.Status.TOO_MANY_REQUESTS.getStatusCode());
|
||||
|
||||
verify(COUNT_RATE_LIMITER).validate(AuthHelper.VALID_UUID);
|
||||
verify(COUNT_RATE_LIMITER).restorePermits(AuthHelper.VALID_UUID, 1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testV4TusForm() {
|
||||
AttachmentDescriptorV3 descriptor = resources.getJerseyTest()
|
||||
|
||||
@ -7,7 +7,9 @@ package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
|
||||
|
||||
@ -60,7 +62,9 @@ class AttachmentsGrpcServiceTest extends
|
||||
@Mock
|
||||
private ExperimentEnrollmentManager experimentEnrollmentManager;
|
||||
@Mock
|
||||
private RateLimiter rateLimiter;
|
||||
private RateLimiter countRateLimiter;
|
||||
@Mock
|
||||
private RateLimiter byteRateLimiter;
|
||||
|
||||
@Override
|
||||
protected AttachmentsGrpcService createServiceBeforeEachTest() {
|
||||
@ -79,8 +83,10 @@ class AttachmentsGrpcServiceTest extends
|
||||
|
||||
return new AttachmentsGrpcService(
|
||||
experimentEnrollmentManager,
|
||||
MockUtils.buildMock(RateLimiters.class, rateLimiters ->
|
||||
when(rateLimiters.getAttachmentLimiter()).thenReturn(rateLimiter)),
|
||||
MockUtils.buildMock(RateLimiters.class, rateLimiters -> {
|
||||
when(rateLimiters.getAttachmentLimiter()).thenReturn(countRateLimiter);
|
||||
when(rateLimiters.getAttachmentBytesLimiter()).thenReturn(byteRateLimiter);
|
||||
}),
|
||||
gcsAttachmentGenerator,
|
||||
tusAttachmentGenerator,
|
||||
MAX_UPLOAD_LENGTH);
|
||||
@ -103,6 +109,27 @@ class AttachmentsGrpcServiceTest extends
|
||||
.hasExceedsMaxUploadLength()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void countRateLimitExceeded() throws RateLimitExceededException {
|
||||
final Duration retryAfter = Duration.ofSeconds(17);
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(countRateLimiter).validate(AUTHENTICATED_ACI);
|
||||
GrpcTestUtils.assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub()
|
||||
.getUploadForm(GetUploadFormRequest.newBuilder().setUploadLength(123).build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void rollbackRateLimit() throws RateLimitExceededException {
|
||||
final Duration retryAfter = Duration.ofSeconds(17);
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(byteRateLimiter)
|
||||
.validate(AUTHENTICATED_ACI, 123);
|
||||
GrpcTestUtils.assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub()
|
||||
.getUploadForm(GetUploadFormRequest.newBuilder().setUploadLength(123).build()));
|
||||
|
||||
verify(countRateLimiter).validate(AUTHENTICATED_ACI);
|
||||
verify(countRateLimiter).restorePermits(AUTHENTICATED_ACI, 1);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void getUploadFormCdn3() {
|
||||
when(experimentEnrollmentManager.isEnrolled(AUTHENTICATED_ACI, AttachmentUtil.CDN3_EXPERIMENT_NAME))
|
||||
@ -169,10 +196,15 @@ class AttachmentsGrpcServiceTest extends
|
||||
assertThat(credentialParts[4]).isEqualTo("goog4_request");
|
||||
}
|
||||
|
||||
@Test
|
||||
void getUploadFormRateLimited() throws RateLimitExceededException {
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = {true, false})
|
||||
void getUploadFormRateLimited(boolean useByteRateLimiter) throws RateLimitExceededException {
|
||||
final Duration retryAfter = Duration.ofMinutes(5);
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimiter).validate(any(UUID.class));
|
||||
if (useByteRateLimiter) {
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(byteRateLimiter).validate(any(UUID.class), anyLong());
|
||||
} else {
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(countRateLimiter).validate(any(UUID.class));
|
||||
}
|
||||
|
||||
assertRateLimitExceeded(retryAfter, () ->
|
||||
authenticatedServiceStub().getUploadForm(GetUploadFormRequest.newBuilder().setUploadLength(1).build()));
|
||||
|
||||
@ -76,6 +76,30 @@ class LeakyBucketRateLimiterTest {
|
||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(ints = {1, 2, 3, 4, 100})
|
||||
void restorePermits(int permitsToAdd) throws RateLimitExceededException {
|
||||
final int bucketSize = 3;
|
||||
final LeakyBucketRateLimiter rateLimiter = new LeakyBucketRateLimiter(
|
||||
"test",
|
||||
() -> new RateLimiterConfig(bucketSize, Duration.ofHours(1), false),
|
||||
validateRateLimitScript,
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||
retryExecutor,
|
||||
CLOCK);
|
||||
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||
for (int i = 0; i < bucketSize; i++) {
|
||||
rateLimiter.validate(key);
|
||||
}
|
||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||
|
||||
rateLimiter.restorePermits(key, permitsToAdd);
|
||||
for (int i = 0; i < Math.min(permitsToAdd, bucketSize); i++) {
|
||||
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||
}
|
||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = {true, false})
|
||||
void validateAsync(final boolean failOpen) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user