diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index 16a56050a..54764a190 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -53,6 +53,7 @@ import java.util.concurrent.Executor; import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.Nullable; +import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.server.ManagedAsync; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ServiceId; @@ -79,6 +80,7 @@ import org.whispersystems.textsecuregcm.entities.CreateProfileRequest; import org.whispersystems.textsecuregcm.entities.ExpiringProfileKeyCredentialProfileResponse; import org.whispersystems.textsecuregcm.entities.ProfileAvatarUploadAttributes; import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; @@ -94,6 +96,7 @@ import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.VersionedProfile; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.ProfileHelper; @@ -122,6 +125,7 @@ public class ProfileController { private static final String VERSION_NOT_FOUND_COUNTER_NAME = name(ProfileController.class, "versionNotFound"); private static final String DUPLICATE_AUTHENTICATION_COUNTER_NAME = name(ProfileController.class, "duplicateAuthentication"); + private static final String BATCH_IDENTITY_CHECK_RATE_LIMITED_COUNTER_NAME = name(ProfileController.class, "batchIdentityCheckRateLimited"); public ProfileController( Clock clock, @@ -372,7 +376,20 @@ public class ProfileController { description = "Batch check completed successfully. Response may contain accounts with mismatched fingerprints.", content = @Content(schema = @Schema(implementation = BatchIdentityCheckResponse.class))) @ApiResponse(responseCode = "400", description = "Invalid request format or validation failed.") - public CompletableFuture runBatchIdentityCheck(@NotNull @Valid BatchIdentityCheckRequest request) { + public CompletableFuture runBatchIdentityCheck( + @NotNull @Valid final BatchIdentityCheckRequest request, + @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, + @Context final ContainerRequestContext containerRequestContext) { + final String remoteAddress = (String) containerRequestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); + if (StringUtils.isNotBlank(remoteAddress)) { + rateLimiters.getBatchIdentityCheckLimiter().validateAsync(remoteAddress) + .whenComplete((_, t) -> { + if (t != null && ExceptionUtils.unwrap(t) instanceof RateLimitExceededException) { + Metrics.counter(BATCH_IDENTITY_CHECK_RATE_LIMITED_COUNTER_NAME, + Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))).increment(); + } + }); + } return CompletableFuture.supplyAsync(() -> { List responseElements = Collections.synchronizedList(new ArrayList<>()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index a388dd513..53d9e28cf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -58,7 +58,8 @@ public class RateLimiters extends BaseRateLimiters { RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)), WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)), DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", new RateLimiterConfig(10, Duration.ofMinutes(1), false)), - SUBMIT_CALL_QUALITY_SURVEY("submitCallQualitySurvey", new RateLimiterConfig(100, Duration.ofMinutes(1), true)) + SUBMIT_CALL_QUALITY_SURVEY("submitCallQualitySurvey", new RateLimiterConfig(100, Duration.ofMinutes(1), true)), + BATCH_IDENTITY_CHECK("batchIdentityCheck", new RateLimiterConfig(100, Duration.ofMinutes(1), true)), ; private final String id; @@ -232,4 +233,8 @@ public class RateLimiters extends BaseRateLimiters { public RateLimiter getSubmitCallQualitySurveyLimiter() { return forDescriptor(For.SUBMIT_CALL_QUALITY_SURVEY); } + + public RateLimiter getBatchIdentityCheckLimiter() { + return forDescriptor(For.BATCH_IDENTITY_CHECK); + } }