From 9a9b15ee0aa470220e9121b24cf67000884ec5c9 Mon Sep 17 00:00:00 2001 From: ravi-signal <99042880+ravi-signal@users.noreply.github.com> Date: Tue, 23 Jun 2026 13:08:39 -0500 Subject: [PATCH] Add per-element constraint validation --- .../grpc/AccountsGrpcService.java | 6 - .../grpc/ValidatingInterceptor.java | 161 ++++++++++++------ .../validators/Base64UrlFieldValidator.java | 6 +- .../grpc/validators/BaseFieldValidator.java | 143 ---------------- .../grpc/validators/E164FieldValidator.java | 2 +- .../EnumSpecifiedFieldValidator.java | 2 +- .../validators/ExactlySizeFieldValidator.java | 10 +- .../grpc/validators/FieldValidator.java | 129 +++++++++++++- .../validators/NonEmptyFieldValidator.java | 8 +- .../validators/PresentFieldValidator.java | 2 +- .../grpc/validators/RangeFieldValidator.java | 2 +- ...erviceIdentifierIdentityTypeValidator.java | 2 +- .../grpc/validators/SizeFieldValidator.java | 8 +- .../main/proto/org/signal/chat/account.proto | 5 +- .../main/proto/org/signal/chat/messages.proto | 6 +- .../main/proto/org/signal/chat/require.proto | 35 +++- .../grpc/ValidatingInterceptorTest.java | 76 ++++++++- service/src/test/proto/validation_test.proto | 11 +- 18 files changed, 377 insertions(+), 237 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/BaseFieldValidator.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcService.java index f1c1eacd8..15a9b0b88 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcService.java @@ -48,7 +48,6 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; -import org.whispersystems.textsecuregcm.controllers.AccountController; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -136,11 +135,6 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase { final List usernameHashes = new ArrayList<>(request.getUsernameHashesCount()); for (final ByteString usernameHash : request.getUsernameHashesList()) { - if (usernameHash.size() != AccountController.USERNAME_HASH_LENGTH) { - throw GrpcExceptions.fieldViolation("username_hashes", - String.format("Username hash length must be %d bytes, but was actually %d", - AccountController.USERNAME_HASH_LENGTH, usernameHash.size())); - } usernameHashes.add(usernameHash.toByteArray()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ValidatingInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ValidatingInterceptor.java index aec84aa4c..f5f1e02f7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ValidatingInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ValidatingInterceptor.java @@ -12,17 +12,19 @@ import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; +import io.grpc.Status; import io.grpc.StatusRuntimeException; import java.util.List; import java.util.Map; +import org.signal.chat.require.ElementConstraint; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.grpc.validators.Base64UrlFieldValidator; +import org.whispersystems.textsecuregcm.grpc.validators.FieldValidator; import org.whispersystems.textsecuregcm.grpc.validators.E164FieldValidator; import org.whispersystems.textsecuregcm.grpc.validators.EnumSpecifiedFieldValidator; import org.whispersystems.textsecuregcm.grpc.validators.ExactlySizeFieldValidator; import org.whispersystems.textsecuregcm.grpc.validators.FieldValidationException; -import org.whispersystems.textsecuregcm.grpc.validators.FieldValidator; import org.whispersystems.textsecuregcm.grpc.validators.NonEmptyFieldValidator; import org.whispersystems.textsecuregcm.grpc.validators.PresentFieldValidator; import org.whispersystems.textsecuregcm.grpc.validators.RangeFieldValidator; @@ -32,17 +34,22 @@ import org.whispersystems.textsecuregcm.grpc.validators.SizeFieldValidator; public class ValidatingInterceptor implements ServerInterceptor { private static final Logger log = LoggerFactory.getLogger(ValidatingInterceptor.class); - private final Map fieldValidators = Map.of( - "org.signal.chat.require.nonEmpty", new NonEmptyFieldValidator(), - "org.signal.chat.require.present", new PresentFieldValidator(), - "org.signal.chat.require.specified", new EnumSpecifiedFieldValidator(), - "org.signal.chat.require.e164", new E164FieldValidator(), - "org.signal.chat.require.exactlySize", new ExactlySizeFieldValidator(), - "org.signal.chat.require.range", new RangeFieldValidator(), - "org.signal.chat.require.size", new SizeFieldValidator(), - "org.signal.chat.require.base64url", new Base64UrlFieldValidator(), - "org.signal.chat.require.identityType", new ServiceIdentifierIdentityTypeValidator() - ); + private static final String REQUIRE_PATH = "org.signal.chat.require."; + private static final String EACH_PATH = "org.signal.chat.require.each"; + + // The keys in this map correspond to the names of our custom FieldOptions as well as the names in the fields of + // `ElementConstraint` which itself is the value of the `each` FieldOption. For this reason it's a good idea to make + // the names match when they refer to the same validator. + private final static Map> VALIDATORS = Map.of( + "nonEmpty", new NonEmptyFieldValidator(), + "specified", new EnumSpecifiedFieldValidator(), + "e164", new E164FieldValidator(), + "exactlySize", new ExactlySizeFieldValidator(), + "range", new RangeFieldValidator(), + "size", new SizeFieldValidator(), + "base64url", new Base64UrlFieldValidator(), + "identityType", new ServiceIdentifierIdentityTypeValidator(), + "present", new PresentFieldValidator()); @Override public ServerCall.Listener interceptCall( @@ -67,13 +74,13 @@ public class ValidatingInterceptor implements ServerInterceptor { validateMessage(message); super.onMessage(message); } catch (RuntimeException runtimeException) { - final StatusRuntimeException grpcException = switch (runtimeException) { - case StatusRuntimeException e -> e; - default -> { - log.error("Failure applying request validation to message {}", call.getMethodDescriptor().getFullMethodName(), runtimeException); - yield GrpcExceptions.unavailable("failure applying request validation"); - } - }; + final StatusRuntimeException grpcException = runtimeException instanceof StatusRuntimeException sre + ? sre + : GrpcExceptions.unavailable("failure applying request validation"); + if (grpcException.getStatus().getCode() != Status.Code.INVALID_ARGUMENT) { + log.error("Failure applying request validation to message {}", + call.getMethodDescriptor().getFullMethodName(), runtimeException); + } call.close(grpcException.getStatus(), grpcException.getTrailers()); forwardCalls = false; } @@ -89,45 +96,95 @@ public class ValidatingInterceptor implements ServerInterceptor { } private void validateMessage(final Object message) { - if (message instanceof Message msg) { - for (final Descriptors.FieldDescriptor fd : msg.getDescriptorForType().getFields()) { - for (final Map.Entry entry : fd.getOptions().getAllFields().entrySet()) { - final Descriptors.FieldDescriptor extensionFieldDescriptor = entry.getKey(); - final String extensionName = extensionFieldDescriptor.getFullName(); + if (!(message instanceof Message msg)) { + return; + } + for (final Descriptors.FieldDescriptor fd : msg.getDescriptorForType().getFields()) { + for (final Map.Entry entry : fd.getOptions().getAllFields().entrySet()) { + final Descriptors.FieldDescriptor extensionFieldDescriptor = entry.getKey(); - // If this is a oneof, but this field isn't set, we shouldn't validate it. We assume if you have a validator - // that requires presence, you don't actually want presence enforcement on a oneof case. - if (fd.getRealContainingOneof() != null && !msg.hasField(fd)) { - continue; - } - - // first validate the field - final FieldValidator validator = fieldValidators.get(extensionName); - // not all extensions are validators, so `validator` value here could legitimately be `null` - if (validator != null) { - try { - validator.validate(entry.getValue(), fd, msg); - } catch (FieldValidationException e) { - throw GrpcExceptions.fieldViolation(fd.getName(), - "extension %s: %s".formatted(extensionName, e.getMessage())); - } + if (!extensionFieldDescriptor.getFullName().startsWith(REQUIRE_PATH) || (fd.getRealContainingOneof() != null && !msg.hasField(fd))) { + // Either this is a non-require extension, or this is oneof but this field isn't set. In the latter case + // we assume if you have a validator that requires presence, you don't actually want presence enforcement on + // a oneof. In either case we should just skip this validator. + } else if (extensionFieldDescriptor.getFullName().equals(EACH_PATH)) { + if (!(entry.getValue() instanceof ElementConstraint elementConstraint)) { + throw new IllegalStateException("'each' value must be an ElementConstraint"); } + validateRepeatedElementConstraints(elementConstraint, msg, fd); + } else { + validateField(getValidatorOrThrow(extensionFieldDescriptor), entry.getValue(), msg, fd); } + } - // Recursively validate the field's value(s) if it is a message or a repeated field - // gRPC's proto deserialization limits nesting to 100 so this has bounded stack usage - if (fd.isRepeated() && msg.getField(fd) instanceof List list) { - // Checking for repeated fields also handles maps, because maps are syntax sugar for repeated MapEntries - // which themselves are Messages that will be recursively descended. - for (final Object o : list) { - validateMessage(o); - } - } else if (fd.hasPresence() && msg.hasField(fd)) { - // If the field has presence information and is present, recursively validate it. Not all fields have - // presence, but we only validate Message type fields anyway, which always have explicit presence. - validateMessage(msg.getField(fd)); + // Recursively validate the field's value(s) if it is a message or a repeated field + // gRPC's proto deserialization limits nesting to 100 so this has bounded stack usage + if (fd.isRepeated() && msg.getField(fd) instanceof List list) { + // Checking for repeated fields also handles maps, because maps are syntax sugar for repeated MapEntries + // which themselves are Messages that will be recursively descended. + for (final Object o : list) { + validateMessage(o); + } + } else if (fd.hasPresence() && msg.hasField(fd)) { + // If the field has presence information and is present, recursively validate it. Not all fields have + // presence, but we only validate Message type fields anyway, which always have explicit presence. + validateMessage(msg.getField(fd)); + } + } + } + + private void validateField(final FieldValidator validator, final Object extensionValue, final Message msg, final Descriptors.FieldDescriptor fd) { + // for the fields with an `optional` modifier, checking if the field was set + // and if not, checking if extension allows missing optional field + if (fd.hasPresence() && !msg.hasField(fd)) { + switch (validator.getMissingOptionalAction()) { + case FAIL -> throw fieldViolation(fd, validator.getExtensionName(), "extension requires a value to be set"); + case SUCCEED -> { + return; + } + case VALIDATE_DEFAULT_VALUE -> {} + } + } + + try { + validator.validate(extensionValue, fd, msg.getField(fd)); + } catch (FieldValidationException e) { + throw fieldViolation(fd, validator.getExtensionName(), e.getMessage()); + } + } + + private void validateRepeatedElementConstraints( + final ElementConstraint elementConstraint, + final Message message, + final Descriptors.FieldDescriptor fd) { + if (!fd.isRepeated() || fd.isMapField()) { + throw new IllegalStateException("each may only be applied to repeated fields"); + } + + for (final Map.Entry entry : elementConstraint.getAllFields().entrySet()) { + final FieldValidator elementValidator = getValidatorOrThrow(entry.getKey()); + final int count = message.getRepeatedFieldCount(fd); + for (int i = 0; i < count; i++) { + final Object element = message.getRepeatedField(fd, i); + try { + elementValidator.validate(entry.getValue(), fd, element); + } catch (final FieldValidationException e) { + throw fieldViolation(fd, entry.getKey().getFullName(), "element [%d]: %s".formatted(i, e.getMessage())); } } } } + + private StatusRuntimeException fieldViolation(final Descriptors.FieldDescriptor fd, final String extensionName, final String message) { + return GrpcExceptions.fieldViolation(fd.getName(), "extension %s: %s".formatted(extensionName, message)); + } + + private static FieldValidator getValidatorOrThrow(final Descriptors.FieldDescriptor extensionFd) { + final FieldValidator validator = VALIDATORS.get(extensionFd.getName()); + if (validator == null) { + throw new IllegalStateException("unrecognized extension " + extensionFd.getFullName()); + } + return validator; + } + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/Base64UrlFieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/Base64UrlFieldValidator.java index 4197a75c9..b853c70c4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/Base64UrlFieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/Base64UrlFieldValidator.java @@ -6,16 +6,12 @@ package org.whispersystems.textsecuregcm.grpc.validators; import com.google.protobuf.Descriptors; -import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException; -import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException; -import org.whispersystems.textsecuregcm.util.Util; import java.util.Base64; -import java.util.Objects; import java.util.Set; /// Validate that a string field is a valid base64 url string (padded or unpadded) -public class Base64UrlFieldValidator extends BaseFieldValidator { +public class Base64UrlFieldValidator extends FieldValidator { public Base64UrlFieldValidator() { super("base64url", Set.of(Descriptors.FieldDescriptor.Type.STRING), MissingOptionalAction.SUCCEED, false); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/BaseFieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/BaseFieldValidator.java deleted file mode 100644 index f614a6806..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/BaseFieldValidator.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc.validators; - -import static java.util.Objects.requireNonNull; - -import com.google.protobuf.ByteString; -import com.google.protobuf.Descriptors; -import com.google.protobuf.Message; -import java.util.Set; - -public abstract class BaseFieldValidator implements FieldValidator { - - private final String extensionName; - - private final Set supportedTypes; - - private final MissingOptionalAction missingOptionalAction; - - private final boolean applicableToRepeated; - - protected enum MissingOptionalAction { - FAIL, - SUCCEED, - VALIDATE_DEFAULT_VALUE - } - - - protected BaseFieldValidator( - final String extensionName, - final Set supportedTypes, - final MissingOptionalAction missingOptionalAction, - final boolean applicableToRepeated) { - this.extensionName = requireNonNull(extensionName); - this.supportedTypes = requireNonNull(supportedTypes); - this.missingOptionalAction = missingOptionalAction; - this.applicableToRepeated = applicableToRepeated; - } - - @Override - public void validate( - final Object extensionValue, - final Descriptors.FieldDescriptor fd, - final Message msg) throws FieldValidationException { - final T extensionValueTyped = resolveExtensionValue(extensionValue); - - // for the fields with an `optional` modifier, checking if the field was set - // and if not, checking if extension allows missing optional field - if (fd.hasPresence() && !msg.hasField(fd)) { - switch (missingOptionalAction) { - case FAIL -> { - throw new FieldValidationException("extension requires a value to be set"); - } - case SUCCEED -> { - return; - } - case VALIDATE_DEFAULT_VALUE -> { - // just continuing - } - } - } - - // for the `repeated` fields, checking if it's supported by the extension - if (fd.isRepeated()) { - if (applicableToRepeated) { - validateRepeatedField(extensionValueTyped, fd, msg); - return; - } - throw new IllegalArgumentException("can't apply extension %s to `repeated` field %s" - .formatted(extensionName, fd.getFullName())); - } - - // checking field type against the set of supported types - final Descriptors.FieldDescriptor.Type type = fd.getType(); - if (!supportedTypes.contains(type)) { - throw new IllegalArgumentException("can't apply extension %s to field %s of type %s".formatted( - extensionName, fd.getFullName(), type)); - } - switch (type) { - case INT64, UINT64, INT32, FIXED64, FIXED32, UINT32, SFIXED32, SFIXED64, SINT32, SINT64 -> - validateIntegerNumber(extensionValueTyped, ((Number) msg.getField(fd)).longValue(), type); - case STRING -> validateStringValue(extensionValueTyped, (String) msg.getField(fd)); - case BYTES -> validateBytesValue(extensionValueTyped, (ByteString) msg.getField(fd)); - case ENUM -> validateEnumValue(extensionValueTyped, (Descriptors.EnumValueDescriptor) msg.getField(fd)); - case MESSAGE -> { - validateMessageValue(extensionValueTyped, (Message) msg.getField(fd)); - } - case FLOAT, DOUBLE, BOOL, GROUP -> { - // at this moment, there are no validations specific to these types of fields - } - } - - } - - protected abstract T resolveExtensionValue(final Object extensionValue) throws FieldValidationException; - - protected void validateRepeatedField( - final T extensionValue, - final Descriptors.FieldDescriptor fd, - final Message msg) throws FieldValidationException { - throw new UnsupportedOperationException("`validateRepeatedField` method needs to be implemented"); - } - - protected void validateIntegerNumber( - final T extensionValue, - final long fieldValue, final Descriptors.FieldDescriptor.Type type) throws FieldValidationException { - throw new UnsupportedOperationException("`validateIntegerNumber` method needs to be implemented"); - } - - protected void validateStringValue( - final T extensionValue, - final String fieldValue) throws FieldValidationException { - throw new UnsupportedOperationException("`validateStringValue` method needs to be implemented"); - } - - protected void validateBytesValue( - final T extensionValue, - final ByteString fieldValue) throws FieldValidationException { - throw new UnsupportedOperationException("`validateBytesValue` method needs to be implemented"); - } - - protected void validateEnumValue( - final T extensionValue, - final Descriptors.EnumValueDescriptor enumValueDescriptor) throws FieldValidationException { - throw new UnsupportedOperationException("`validateEnumValue` method needs to be implemented"); - } - - protected void validateMessageValue( - final T extensionValue, - final Message message) throws FieldValidationException { - throw new UnsupportedOperationException("`validateMessageValue` method needs to be implemented"); - } - - protected static boolean requireFlagExtension(final Object extensionValue) throws FieldValidationException { - if (extensionValue instanceof Boolean flagIsOn && flagIsOn) { - return true; - } - throw new UnsupportedOperationException("only value `true` is allowed"); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/E164FieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/E164FieldValidator.java index 424b0484e..606c9fa5c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/E164FieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/E164FieldValidator.java @@ -11,7 +11,7 @@ import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException; import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException; import org.whispersystems.textsecuregcm.util.Util; -public class E164FieldValidator extends BaseFieldValidator { +public class E164FieldValidator extends FieldValidator { public E164FieldValidator() { super("e164", Set.of(Descriptors.FieldDescriptor.Type.STRING), MissingOptionalAction.SUCCEED, false); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/EnumSpecifiedFieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/EnumSpecifiedFieldValidator.java index 5bf104943..66bc19dda 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/EnumSpecifiedFieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/EnumSpecifiedFieldValidator.java @@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc.validators; import com.google.protobuf.Descriptors; import java.util.Set; -public class EnumSpecifiedFieldValidator extends BaseFieldValidator { +public class EnumSpecifiedFieldValidator extends FieldValidator { public EnumSpecifiedFieldValidator() { super("specified", Set.of(Descriptors.FieldDescriptor.Type.ENUM), MissingOptionalAction.FAIL, false); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/ExactlySizeFieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/ExactlySizeFieldValidator.java index d3402d024..913e8f5e3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/ExactlySizeFieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/ExactlySizeFieldValidator.java @@ -7,11 +7,10 @@ package org.whispersystems.textsecuregcm.grpc.validators; import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors; -import com.google.protobuf.Message; import java.util.List; import java.util.Set; -public class ExactlySizeFieldValidator extends BaseFieldValidator> { +public class ExactlySizeFieldValidator extends FieldValidator> { public ExactlySizeFieldValidator() { super("exactlySize", Set.of( @@ -50,11 +49,10 @@ public class ExactlySizeFieldValidator extends BaseFieldValidator> protected void validateRepeatedField( final Set permittedSizes, final Descriptors.FieldDescriptor fd, - final Message msg) throws FieldValidationException { - final int size = msg.getRepeatedFieldCount(fd); - if (permittedSizes.contains(size)) { + final List repeated) throws FieldValidationException { + if (permittedSizes.contains(repeated.size())) { return; } - throw new FieldValidationException("list size is [%d] but expected to be one of %s".formatted(size, permittedSizes)); + throw new FieldValidationException("list size is [%d] but expected to be one of %s".formatted(repeated.size(), permittedSizes)); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/FieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/FieldValidator.java index c52faecbc..f1a550d58 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/FieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/FieldValidator.java @@ -5,11 +5,134 @@ package org.whispersystems.textsecuregcm.grpc.validators; +import static java.util.Objects.requireNonNull; + +import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors; import com.google.protobuf.Message; +import java.util.List; +import java.util.Set; -public interface FieldValidator { +public abstract class FieldValidator { - void validate(Object extensionValue, Descriptors.FieldDescriptor fd, Message msg) - throws FieldValidationException; + private final String extensionName; + + private final Set supportedTypes; + + private final MissingOptionalAction missingOptionalAction; + + private final boolean applicableToRepeated; + + public enum MissingOptionalAction { + FAIL, + SUCCEED, + VALIDATE_DEFAULT_VALUE + } + + + protected FieldValidator( + final String extensionName, + final Set supportedTypes, + final MissingOptionalAction missingOptionalAction, + final boolean applicableToRepeated) { + this.extensionName = requireNonNull(extensionName); + this.supportedTypes = requireNonNull(supportedTypes); + this.missingOptionalAction = missingOptionalAction; + this.applicableToRepeated = applicableToRepeated; + } + + public MissingOptionalAction getMissingOptionalAction() { + return missingOptionalAction; + } + + public String getExtensionName() { + return extensionName; + } + + + /// Validate a field + /// + /// @param extensionValue The value of the option this field was annotated with + /// @param fd The descriptor of the field to validate + /// @param fieldValue The value of the field to validate + /// @throws FieldValidationException if a field constraint was violated + public final void validate( + final Object extensionValue, + final Descriptors.FieldDescriptor fd, + final Object fieldValue) throws FieldValidationException { + final T extensionValueTyped = resolveExtensionValue(extensionValue); + + // for the `repeated` fields, checking if it's supported by the extension + if (fd.isRepeated() && fieldValue instanceof List list) { + if (!applicableToRepeated) { + throw new IllegalArgumentException("can't apply extension %s to `repeated` field %s" + .formatted(extensionName, fd.getFullName())); + } + validateRepeatedField(extensionValueTyped, fd, list); + return; + } + final Descriptors.FieldDescriptor.Type type = fd.getType(); + if (!supportedTypes.contains(type)) { + throw new IllegalArgumentException("can't apply extension %s to field %s of type %s".formatted( + extensionName, fd.getFullName(), type)); + } + switch (type) { + case INT64, UINT64, INT32, FIXED64, FIXED32, UINT32, SFIXED32, SFIXED64, SINT32, SINT64 -> + validateIntegerNumber(extensionValueTyped, ((Number) fieldValue).longValue(), type); + case STRING -> validateStringValue(extensionValueTyped, (String) fieldValue); + case BYTES -> validateBytesValue(extensionValueTyped, (ByteString) fieldValue); + case ENUM -> validateEnumValue(extensionValueTyped, (Descriptors.EnumValueDescriptor) fieldValue); + case MESSAGE -> validateMessageValue(extensionValueTyped, (Message) fieldValue); + case FLOAT, DOUBLE, BOOL, GROUP -> { + // at this moment, there are no validations specific to these types of fields + } + } + + } + + protected abstract T resolveExtensionValue(final Object extensionValue) throws FieldValidationException; + + protected void validateRepeatedField( + final T extensionValue, + final Descriptors.FieldDescriptor fd, + final List repeated) throws FieldValidationException { + throw new UnsupportedOperationException("`validateRepeatedField` method needs to be implemented"); + } + + protected void validateIntegerNumber( + final T extensionValue, + final long fieldValue, final Descriptors.FieldDescriptor.Type type) throws FieldValidationException { + throw new UnsupportedOperationException("`validateIntegerNumber` method needs to be implemented"); + } + + protected void validateStringValue( + final T extensionValue, + final String fieldValue) throws FieldValidationException { + throw new UnsupportedOperationException("`validateStringValue` method needs to be implemented"); + } + + protected void validateBytesValue( + final T extensionValue, + final ByteString fieldValue) throws FieldValidationException { + throw new UnsupportedOperationException("`validateBytesValue` method needs to be implemented"); + } + + protected void validateEnumValue( + final T extensionValue, + final Descriptors.EnumValueDescriptor enumValueDescriptor) throws FieldValidationException { + throw new UnsupportedOperationException("`validateEnumValue` method needs to be implemented"); + } + + protected void validateMessageValue( + final T extensionValue, + final Message message) throws FieldValidationException { + throw new UnsupportedOperationException("`validateMessageValue` method needs to be implemented"); + } + + protected static boolean requireFlagExtension(final Object extensionValue) throws FieldValidationException { + if (extensionValue instanceof Boolean flagIsOn && flagIsOn) { + return true; + } + throw new UnsupportedOperationException("only value `true` is allowed"); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/NonEmptyFieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/NonEmptyFieldValidator.java index 6989d7aab..62d7d83fb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/NonEmptyFieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/NonEmptyFieldValidator.java @@ -7,11 +7,11 @@ package org.whispersystems.textsecuregcm.grpc.validators; import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors; -import com.google.protobuf.Message; +import java.util.List; import java.util.Set; import org.apache.commons.lang3.StringUtils; -public class NonEmptyFieldValidator extends BaseFieldValidator { +public class NonEmptyFieldValidator extends FieldValidator { public NonEmptyFieldValidator() { super("nonEmpty", Set.of( @@ -49,8 +49,8 @@ public class NonEmptyFieldValidator extends BaseFieldValidator { protected void validateRepeatedField( final Boolean extensionValue, final Descriptors.FieldDescriptor fd, - final Message msg) throws FieldValidationException { - if (msg.getRepeatedFieldCount(fd) > 0) { + final List repeated) throws FieldValidationException { + if (repeated.size() > 0) { return; } throw new FieldValidationException("repeated field is expected to be non-empty"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/PresentFieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/PresentFieldValidator.java index e025dca9a..a25e05e95 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/PresentFieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/PresentFieldValidator.java @@ -9,7 +9,7 @@ import com.google.protobuf.Descriptors; import com.google.protobuf.Message; import java.util.Set; -public class PresentFieldValidator extends BaseFieldValidator { +public class PresentFieldValidator extends FieldValidator { public PresentFieldValidator() { super("present", diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/RangeFieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/RangeFieldValidator.java index 0af3a84d7..22ba15ba5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/RangeFieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/RangeFieldValidator.java @@ -9,7 +9,7 @@ import com.google.protobuf.Descriptors; import java.util.Set; import org.signal.chat.require.ValueRangeConstraint; -public class RangeFieldValidator extends BaseFieldValidator { +public class RangeFieldValidator extends FieldValidator { private static final Set UNSIGNED_TYPES = Set.of( Descriptors.FieldDescriptor.Type.FIXED32, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/ServiceIdentifierIdentityTypeValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/ServiceIdentifierIdentityTypeValidator.java index 6adc29fa3..9fb889d69 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/ServiceIdentifierIdentityTypeValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/ServiceIdentifierIdentityTypeValidator.java @@ -12,7 +12,7 @@ import java.util.Set; import org.signal.chat.common.IdentityType; import org.signal.chat.common.ServiceIdentifier; -public class ServiceIdentifierIdentityTypeValidator extends BaseFieldValidator { +public class ServiceIdentifierIdentityTypeValidator extends FieldValidator { public ServiceIdentifierIdentityTypeValidator() { super("identityType", Set.of(Descriptors.FieldDescriptor.Type.MESSAGE), MissingOptionalAction.SUCCEED, false); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/SizeFieldValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/SizeFieldValidator.java index bfe1dbcbc..17c884c1a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/SizeFieldValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/validators/SizeFieldValidator.java @@ -7,11 +7,11 @@ package org.whispersystems.textsecuregcm.grpc.validators; import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors; -import com.google.protobuf.Message; +import java.util.List; import java.util.Set; import org.signal.chat.require.SizeConstraint; -public class SizeFieldValidator extends BaseFieldValidator { +public class SizeFieldValidator extends FieldValidator { public SizeFieldValidator() { super("size", Set.of( @@ -45,8 +45,8 @@ public class SizeFieldValidator extends BaseFieldValidator { } @Override - protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final Message msg) throws FieldValidationException { - final int size = msg.getRepeatedFieldCount(fd); + protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final List repeated) throws FieldValidationException { + final int size = repeated.size(); if (size < range.min() || size > range.max()) { throw new FieldValidationException("field value is [%d] but expected to be within the [%d, %d] range".formatted( size, range.min(), range.max())); diff --git a/service/src/main/proto/org/signal/chat/account.proto b/service/src/main/proto/org/signal/chat/account.proto index 3bb16aba7..e08173f01 100644 --- a/service/src/main/proto/org/signal/chat/account.proto +++ b/service/src/main/proto/org/signal/chat/account.proto @@ -105,9 +105,8 @@ message ClearRegistrationLockResponse { } message ReserveUsernameHashRequest { - // A prioritized list of username hashes to attempt to reserve. Each hash must - // be exactly 32 bytes. - repeated bytes username_hashes = 1 [(require.size) = {min: 1, max: 20}]; + // A prioritized list of username hashes to attempt to reserve. + repeated bytes username_hashes = 1 [(require.size) = {min: 1, max: 20}, (require.each) = {exactlySize: 32}]; } message UsernameNotAvailable {} diff --git a/service/src/main/proto/org/signal/chat/messages.proto b/service/src/main/proto/org/signal/chat/messages.proto index ca39f303d..2ba5b2389 100644 --- a/service/src/main/proto/org/signal/chat/messages.proto +++ b/service/src/main/proto/org/signal/chat/messages.proto @@ -385,19 +385,19 @@ message MismatchedDevices { // A list of device IDs that are linked to the destination account, but were // not included in the collection of messages bound for the destination // account. - repeated uint32 missing_devices = 2 [(require.range).max = 0x7f]; + repeated uint32 missing_devices = 2 [(require.each) = {range: {max: 0x7f}}]; // A list of device IDs that were included in the collection of messages bound // for the destination account, but are not currently linked to the // destination account. - repeated uint32 extra_devices = 3 [(require.range).max = 0x7f]; + repeated uint32 extra_devices = 3 [(require.each) = {range: {max: 0x7f}}]; // A list of device IDs that present in the collection of messages bound for // the destination account and are linked to the destination account, but have // a different registration ID than the registration ID presented by the // sender (indicating that the destination device has likely been replaced by // another device). - repeated uint32 stale_devices = 4 [(require.range).max = 0x7f]; + repeated uint32 stale_devices = 4 [(require.each) = {range: {max: 0x7f}}]; } message MultiRecipientMismatchedDevices { diff --git a/service/src/main/proto/org/signal/chat/require.proto b/service/src/main/proto/org/signal/chat/require.proto index f3b8c5a57..868a3feb9 100644 --- a/service/src/main/proto/org/signal/chat/require.proto +++ b/service/src/main/proto/org/signal/chat/require.proto @@ -178,7 +178,40 @@ extend google.protobuf.FieldOptions { */ optional IdentityType identityType = 70009; - // next 70010 + /* + * Applies element-wise constraints to the elements of a `repeated` field. + * Top-level `require.*` annotations on a `repeated` field constrain the + * collection itself (e.g. element count), `each` constrains the + * individual elements. + * + * ``` + * import "org/signal/chat/require.proto"; + * + * message Data { + * // 1-20 username hashes, each exactly 32 bytes + * repeated bytes username_hashes = 1 [ + * (require.size) = {min: 1, max: 20}, + * (require.each) = { exactlySize: 32 } + * ]; + * } + * ``` + * + * Applicable only to `repeated` fields. + */ + optional ElementConstraint each = 70010; + + // next 70011 +} + +message ElementConstraint { + optional bool nonEmpty = 1; + optional SizeConstraint size = 2; + repeated uint32 exactlySize = 3; + optional bool e164 = 4; + optional bool base64url = 5; + optional ValueRangeConstraint range = 6; + optional IdentityType identityType = 7; + optional bool specified = 8; } message SizeConstraint { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ValidatingInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ValidatingInterceptorTest.java index a8629c55e..2fa19d557 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ValidatingInterceptorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ValidatingInterceptorTest.java @@ -445,6 +445,79 @@ public class ValidatingInterceptorTest { stub.validationsEndpoint(builderWithValidDefaults().build()); } + @Test + public void eachExactlySize() { + assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults().build())); + + assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults() + .addEachExactlyBytes(ByteString.copyFrom(new byte[4])) + .addEachExactlyBytes(ByteString.copyFrom(new byte[4])) + .build())); + + // second element is the wrong length + assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults() + .addEachExactlyBytes(ByteString.copyFrom(new byte[4])) + .addEachExactlyBytes(ByteString.copyFrom(new byte[3])) + .build())); + } + + @Test + public void eachRange() { + assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults() + .addEachRange(1) + .addEachRange(127) + .build())); + + assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults() + .addEachRange(0) + .build())); + assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults() + .addEachRange(128) + .build())); + } + + @Test + public void eachNonEmpty() { + assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults() + .addEachNonEmpty("a") + .build())); + + assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults() + .addEachNonEmpty("a") + .addEachNonEmpty("") + .build())); + } + + @Test + public void eachCombinedWithCollectionConstraint() { + // empty list fails the collection-level `size {min:1}` constraint + assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults() + .clearEachCombined() + .build())); + + // too many elements (collection-level `size {max:3}`) + assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults() + .clearEachCombined() + .addEachCombined(ByteString.copyFrom(new byte[4])) + .addEachCombined(ByteString.copyFrom(new byte[4])) + .addEachCombined(ByteString.copyFrom(new byte[4])) + .addEachCombined(ByteString.copyFrom(new byte[4])) + .build())); + + // count is in range but an element is the wrong length + assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults() + .clearEachCombined() + .addEachCombined(ByteString.copyFrom(new byte[3])) + .build())); + + // both constraints satisfied + assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults() + .clearEachCombined() + .addEachCombined(ByteString.copyFrom(new byte[4])) + .addEachCombined(ByteString.copyFrom(new byte[4])) + .build())); + } + @Test public void testFailedValidationOnNestedMessage() { assertStatusException(Status.INVALID_ARGUMENT, () -> @@ -528,7 +601,8 @@ public class ValidatingInterceptorTest { .setPniServiceIdentifier(ServiceIdentifier.newBuilder() .setIdentityType(IdentityType.IDENTITY_TYPE_PNI) .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) - .build()); + .build()) + .addEachCombined(ByteString.copyFrom(new byte[4])); } private static void assertStatusException(final Status expected, final Executable serviceCall) { diff --git a/service/src/test/proto/validation_test.proto b/service/src/test/proto/validation_test.proto index 6ab6520d9..8e46bef60 100644 --- a/service/src/test/proto/validation_test.proto +++ b/service/src/test/proto/validation_test.proto @@ -84,7 +84,16 @@ message ValidationsRequest { optional common.ServiceIdentifier aciServiceIdentifier = 34 [(require.identityType) = IDENTITY_TYPE_ACI]; optional common.ServiceIdentifier pniServiceIdentifier = 35 [(require.identityType) = IDENTITY_TYPE_PNI]; - // next 36 + repeated bytes eachExactlyBytes = 36 [(require.each) = { exactlySize: 4 }]; + repeated uint32 eachRange = 37 [(require.each) = { range: { min: 1, max: 127 } }]; + repeated string eachNonEmpty = 38 [(require.each) = { nonEmpty: true }]; + // collection-level size combined with per-element size + repeated bytes eachCombined = 39 [ + (require.size) = { min: 1, max: 3 }, + (require.each) = { exactlySize: 4 } + ]; + + // next 40 } message NestedMessage {