Add per-element constraint validation
This commit is contained in:
parent
503941ec6a
commit
9a9b15ee0a
@ -48,7 +48,6 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
|
||||
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
|
||||
import org.whispersystems.textsecuregcm.controllers.AccountController;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
@ -136,11 +135,6 @@ public class AccountsGrpcService extends SimpleAccountsGrpc.AccountsImplBase {
|
||||
final List<byte[]> usernameHashes = new ArrayList<>(request.getUsernameHashesCount());
|
||||
|
||||
for (final ByteString usernameHash : request.getUsernameHashesList()) {
|
||||
if (usernameHash.size() != AccountController.USERNAME_HASH_LENGTH) {
|
||||
throw GrpcExceptions.fieldViolation("username_hashes",
|
||||
String.format("Username hash length must be %d bytes, but was actually %d",
|
||||
AccountController.USERNAME_HASH_LENGTH, usernameHash.size()));
|
||||
}
|
||||
usernameHashes.add(usernameHash.toByteArray());
|
||||
}
|
||||
|
||||
|
||||
@ -12,17 +12,19 @@ import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.signal.chat.require.ElementConstraint;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.Base64UrlFieldValidator;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.FieldValidator;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.E164FieldValidator;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.EnumSpecifiedFieldValidator;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.ExactlySizeFieldValidator;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.FieldValidationException;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.FieldValidator;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.NonEmptyFieldValidator;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.PresentFieldValidator;
|
||||
import org.whispersystems.textsecuregcm.grpc.validators.RangeFieldValidator;
|
||||
@ -32,17 +34,22 @@ import org.whispersystems.textsecuregcm.grpc.validators.SizeFieldValidator;
|
||||
public class ValidatingInterceptor implements ServerInterceptor {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ValidatingInterceptor.class);
|
||||
private final Map<String, FieldValidator> fieldValidators = Map.of(
|
||||
"org.signal.chat.require.nonEmpty", new NonEmptyFieldValidator(),
|
||||
"org.signal.chat.require.present", new PresentFieldValidator(),
|
||||
"org.signal.chat.require.specified", new EnumSpecifiedFieldValidator(),
|
||||
"org.signal.chat.require.e164", new E164FieldValidator(),
|
||||
"org.signal.chat.require.exactlySize", new ExactlySizeFieldValidator(),
|
||||
"org.signal.chat.require.range", new RangeFieldValidator(),
|
||||
"org.signal.chat.require.size", new SizeFieldValidator(),
|
||||
"org.signal.chat.require.base64url", new Base64UrlFieldValidator(),
|
||||
"org.signal.chat.require.identityType", new ServiceIdentifierIdentityTypeValidator()
|
||||
);
|
||||
private static final String REQUIRE_PATH = "org.signal.chat.require.";
|
||||
private static final String EACH_PATH = "org.signal.chat.require.each";
|
||||
|
||||
// The keys in this map correspond to the names of our custom FieldOptions as well as the names in the fields of
|
||||
// `ElementConstraint` which itself is the value of the `each` FieldOption. For this reason it's a good idea to make
|
||||
// the names match when they refer to the same validator.
|
||||
private final static Map<String, FieldValidator<?>> VALIDATORS = Map.of(
|
||||
"nonEmpty", new NonEmptyFieldValidator(),
|
||||
"specified", new EnumSpecifiedFieldValidator(),
|
||||
"e164", new E164FieldValidator(),
|
||||
"exactlySize", new ExactlySizeFieldValidator(),
|
||||
"range", new RangeFieldValidator(),
|
||||
"size", new SizeFieldValidator(),
|
||||
"base64url", new Base64UrlFieldValidator(),
|
||||
"identityType", new ServiceIdentifierIdentityTypeValidator(),
|
||||
"present", new PresentFieldValidator());
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
|
||||
@ -67,13 +74,13 @@ public class ValidatingInterceptor implements ServerInterceptor {
|
||||
validateMessage(message);
|
||||
super.onMessage(message);
|
||||
} catch (RuntimeException runtimeException) {
|
||||
final StatusRuntimeException grpcException = switch (runtimeException) {
|
||||
case StatusRuntimeException e -> e;
|
||||
default -> {
|
||||
log.error("Failure applying request validation to message {}", call.getMethodDescriptor().getFullMethodName(), runtimeException);
|
||||
yield GrpcExceptions.unavailable("failure applying request validation");
|
||||
}
|
||||
};
|
||||
final StatusRuntimeException grpcException = runtimeException instanceof StatusRuntimeException sre
|
||||
? sre
|
||||
: GrpcExceptions.unavailable("failure applying request validation");
|
||||
if (grpcException.getStatus().getCode() != Status.Code.INVALID_ARGUMENT) {
|
||||
log.error("Failure applying request validation to message {}",
|
||||
call.getMethodDescriptor().getFullMethodName(), runtimeException);
|
||||
}
|
||||
call.close(grpcException.getStatus(), grpcException.getTrailers());
|
||||
forwardCalls = false;
|
||||
}
|
||||
@ -89,45 +96,95 @@ public class ValidatingInterceptor implements ServerInterceptor {
|
||||
}
|
||||
|
||||
private void validateMessage(final Object message) {
|
||||
if (message instanceof Message msg) {
|
||||
for (final Descriptors.FieldDescriptor fd : msg.getDescriptorForType().getFields()) {
|
||||
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry : fd.getOptions().getAllFields().entrySet()) {
|
||||
final Descriptors.FieldDescriptor extensionFieldDescriptor = entry.getKey();
|
||||
final String extensionName = extensionFieldDescriptor.getFullName();
|
||||
if (!(message instanceof Message msg)) {
|
||||
return;
|
||||
}
|
||||
for (final Descriptors.FieldDescriptor fd : msg.getDescriptorForType().getFields()) {
|
||||
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry : fd.getOptions().getAllFields().entrySet()) {
|
||||
final Descriptors.FieldDescriptor extensionFieldDescriptor = entry.getKey();
|
||||
|
||||
// If this is a oneof, but this field isn't set, we shouldn't validate it. We assume if you have a validator
|
||||
// that requires presence, you don't actually want presence enforcement on a oneof case.
|
||||
if (fd.getRealContainingOneof() != null && !msg.hasField(fd)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// first validate the field
|
||||
final FieldValidator validator = fieldValidators.get(extensionName);
|
||||
// not all extensions are validators, so `validator` value here could legitimately be `null`
|
||||
if (validator != null) {
|
||||
try {
|
||||
validator.validate(entry.getValue(), fd, msg);
|
||||
} catch (FieldValidationException e) {
|
||||
throw GrpcExceptions.fieldViolation(fd.getName(),
|
||||
"extension %s: %s".formatted(extensionName, e.getMessage()));
|
||||
}
|
||||
if (!extensionFieldDescriptor.getFullName().startsWith(REQUIRE_PATH) || (fd.getRealContainingOneof() != null && !msg.hasField(fd))) {
|
||||
// Either this is a non-require extension, or this is oneof but this field isn't set. In the latter case
|
||||
// we assume if you have a validator that requires presence, you don't actually want presence enforcement on
|
||||
// a oneof. In either case we should just skip this validator.
|
||||
} else if (extensionFieldDescriptor.getFullName().equals(EACH_PATH)) {
|
||||
if (!(entry.getValue() instanceof ElementConstraint elementConstraint)) {
|
||||
throw new IllegalStateException("'each' value must be an ElementConstraint");
|
||||
}
|
||||
validateRepeatedElementConstraints(elementConstraint, msg, fd);
|
||||
} else {
|
||||
validateField(getValidatorOrThrow(extensionFieldDescriptor), entry.getValue(), msg, fd);
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively validate the field's value(s) if it is a message or a repeated field
|
||||
// gRPC's proto deserialization limits nesting to 100 so this has bounded stack usage
|
||||
if (fd.isRepeated() && msg.getField(fd) instanceof List list) {
|
||||
// Checking for repeated fields also handles maps, because maps are syntax sugar for repeated MapEntries
|
||||
// which themselves are Messages that will be recursively descended.
|
||||
for (final Object o : list) {
|
||||
validateMessage(o);
|
||||
}
|
||||
} else if (fd.hasPresence() && msg.hasField(fd)) {
|
||||
// If the field has presence information and is present, recursively validate it. Not all fields have
|
||||
// presence, but we only validate Message type fields anyway, which always have explicit presence.
|
||||
validateMessage(msg.getField(fd));
|
||||
// Recursively validate the field's value(s) if it is a message or a repeated field
|
||||
// gRPC's proto deserialization limits nesting to 100 so this has bounded stack usage
|
||||
if (fd.isRepeated() && msg.getField(fd) instanceof List list) {
|
||||
// Checking for repeated fields also handles maps, because maps are syntax sugar for repeated MapEntries
|
||||
// which themselves are Messages that will be recursively descended.
|
||||
for (final Object o : list) {
|
||||
validateMessage(o);
|
||||
}
|
||||
} else if (fd.hasPresence() && msg.hasField(fd)) {
|
||||
// If the field has presence information and is present, recursively validate it. Not all fields have
|
||||
// presence, but we only validate Message type fields anyway, which always have explicit presence.
|
||||
validateMessage(msg.getField(fd));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void validateField(final FieldValidator<?> validator, final Object extensionValue, final Message msg, final Descriptors.FieldDescriptor fd) {
|
||||
// for the fields with an `optional` modifier, checking if the field was set
|
||||
// and if not, checking if extension allows missing optional field
|
||||
if (fd.hasPresence() && !msg.hasField(fd)) {
|
||||
switch (validator.getMissingOptionalAction()) {
|
||||
case FAIL -> throw fieldViolation(fd, validator.getExtensionName(), "extension requires a value to be set");
|
||||
case SUCCEED -> {
|
||||
return;
|
||||
}
|
||||
case VALIDATE_DEFAULT_VALUE -> {}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
validator.validate(extensionValue, fd, msg.getField(fd));
|
||||
} catch (FieldValidationException e) {
|
||||
throw fieldViolation(fd, validator.getExtensionName(), e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private void validateRepeatedElementConstraints(
|
||||
final ElementConstraint elementConstraint,
|
||||
final Message message,
|
||||
final Descriptors.FieldDescriptor fd) {
|
||||
if (!fd.isRepeated() || fd.isMapField()) {
|
||||
throw new IllegalStateException("each may only be applied to repeated fields");
|
||||
}
|
||||
|
||||
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry : elementConstraint.getAllFields().entrySet()) {
|
||||
final FieldValidator<?> elementValidator = getValidatorOrThrow(entry.getKey());
|
||||
final int count = message.getRepeatedFieldCount(fd);
|
||||
for (int i = 0; i < count; i++) {
|
||||
final Object element = message.getRepeatedField(fd, i);
|
||||
try {
|
||||
elementValidator.validate(entry.getValue(), fd, element);
|
||||
} catch (final FieldValidationException e) {
|
||||
throw fieldViolation(fd, entry.getKey().getFullName(), "element [%d]: %s".formatted(i, e.getMessage()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private StatusRuntimeException fieldViolation(final Descriptors.FieldDescriptor fd, final String extensionName, final String message) {
|
||||
return GrpcExceptions.fieldViolation(fd.getName(), "extension %s: %s".formatted(extensionName, message));
|
||||
}
|
||||
|
||||
private static FieldValidator<?> getValidatorOrThrow(final Descriptors.FieldDescriptor extensionFd) {
|
||||
final FieldValidator<?> validator = VALIDATORS.get(extensionFd.getName());
|
||||
if (validator == null) {
|
||||
throw new IllegalStateException("unrecognized extension " + extensionFd.getFullName());
|
||||
}
|
||||
return validator;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -6,16 +6,12 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
import com.google.protobuf.Descriptors;
|
||||
import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException;
|
||||
import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
import java.util.Base64;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
/// Validate that a string field is a valid base64 url string (padded or unpadded)
|
||||
public class Base64UrlFieldValidator extends BaseFieldValidator<Boolean> {
|
||||
public class Base64UrlFieldValidator extends FieldValidator<Boolean> {
|
||||
|
||||
public Base64UrlFieldValidator() {
|
||||
super("base64url", Set.of(Descriptors.FieldDescriptor.Type.STRING), MissingOptionalAction.SUCCEED, false);
|
||||
|
||||
@ -1,143 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.Message;
|
||||
import java.util.Set;
|
||||
|
||||
public abstract class BaseFieldValidator<T> implements FieldValidator {
|
||||
|
||||
private final String extensionName;
|
||||
|
||||
private final Set<Descriptors.FieldDescriptor.Type> supportedTypes;
|
||||
|
||||
private final MissingOptionalAction missingOptionalAction;
|
||||
|
||||
private final boolean applicableToRepeated;
|
||||
|
||||
protected enum MissingOptionalAction {
|
||||
FAIL,
|
||||
SUCCEED,
|
||||
VALIDATE_DEFAULT_VALUE
|
||||
}
|
||||
|
||||
|
||||
protected BaseFieldValidator(
|
||||
final String extensionName,
|
||||
final Set<Descriptors.FieldDescriptor.Type> supportedTypes,
|
||||
final MissingOptionalAction missingOptionalAction,
|
||||
final boolean applicableToRepeated) {
|
||||
this.extensionName = requireNonNull(extensionName);
|
||||
this.supportedTypes = requireNonNull(supportedTypes);
|
||||
this.missingOptionalAction = missingOptionalAction;
|
||||
this.applicableToRepeated = applicableToRepeated;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(
|
||||
final Object extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final Message msg) throws FieldValidationException {
|
||||
final T extensionValueTyped = resolveExtensionValue(extensionValue);
|
||||
|
||||
// for the fields with an `optional` modifier, checking if the field was set
|
||||
// and if not, checking if extension allows missing optional field
|
||||
if (fd.hasPresence() && !msg.hasField(fd)) {
|
||||
switch (missingOptionalAction) {
|
||||
case FAIL -> {
|
||||
throw new FieldValidationException("extension requires a value to be set");
|
||||
}
|
||||
case SUCCEED -> {
|
||||
return;
|
||||
}
|
||||
case VALIDATE_DEFAULT_VALUE -> {
|
||||
// just continuing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for the `repeated` fields, checking if it's supported by the extension
|
||||
if (fd.isRepeated()) {
|
||||
if (applicableToRepeated) {
|
||||
validateRepeatedField(extensionValueTyped, fd, msg);
|
||||
return;
|
||||
}
|
||||
throw new IllegalArgumentException("can't apply extension %s to `repeated` field %s"
|
||||
.formatted(extensionName, fd.getFullName()));
|
||||
}
|
||||
|
||||
// checking field type against the set of supported types
|
||||
final Descriptors.FieldDescriptor.Type type = fd.getType();
|
||||
if (!supportedTypes.contains(type)) {
|
||||
throw new IllegalArgumentException("can't apply extension %s to field %s of type %s".formatted(
|
||||
extensionName, fd.getFullName(), type));
|
||||
}
|
||||
switch (type) {
|
||||
case INT64, UINT64, INT32, FIXED64, FIXED32, UINT32, SFIXED32, SFIXED64, SINT32, SINT64 ->
|
||||
validateIntegerNumber(extensionValueTyped, ((Number) msg.getField(fd)).longValue(), type);
|
||||
case STRING -> validateStringValue(extensionValueTyped, (String) msg.getField(fd));
|
||||
case BYTES -> validateBytesValue(extensionValueTyped, (ByteString) msg.getField(fd));
|
||||
case ENUM -> validateEnumValue(extensionValueTyped, (Descriptors.EnumValueDescriptor) msg.getField(fd));
|
||||
case MESSAGE -> {
|
||||
validateMessageValue(extensionValueTyped, (Message) msg.getField(fd));
|
||||
}
|
||||
case FLOAT, DOUBLE, BOOL, GROUP -> {
|
||||
// at this moment, there are no validations specific to these types of fields
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
protected abstract T resolveExtensionValue(final Object extensionValue) throws FieldValidationException;
|
||||
|
||||
protected void validateRepeatedField(
|
||||
final T extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final Message msg) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateRepeatedField` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateIntegerNumber(
|
||||
final T extensionValue,
|
||||
final long fieldValue, final Descriptors.FieldDescriptor.Type type) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateIntegerNumber` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateStringValue(
|
||||
final T extensionValue,
|
||||
final String fieldValue) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateStringValue` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateBytesValue(
|
||||
final T extensionValue,
|
||||
final ByteString fieldValue) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateBytesValue` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateEnumValue(
|
||||
final T extensionValue,
|
||||
final Descriptors.EnumValueDescriptor enumValueDescriptor) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateEnumValue` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateMessageValue(
|
||||
final T extensionValue,
|
||||
final Message message) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateMessageValue` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected static boolean requireFlagExtension(final Object extensionValue) throws FieldValidationException {
|
||||
if (extensionValue instanceof Boolean flagIsOn && flagIsOn) {
|
||||
return true;
|
||||
}
|
||||
throw new UnsupportedOperationException("only value `true` is allowed");
|
||||
}
|
||||
}
|
||||
@ -11,7 +11,7 @@ import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException;
|
||||
import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
public class E164FieldValidator extends BaseFieldValidator<Boolean> {
|
||||
public class E164FieldValidator extends FieldValidator<Boolean> {
|
||||
|
||||
public E164FieldValidator() {
|
||||
super("e164", Set.of(Descriptors.FieldDescriptor.Type.STRING), MissingOptionalAction.SUCCEED, false);
|
||||
|
||||
@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import java.util.Set;
|
||||
|
||||
public class EnumSpecifiedFieldValidator extends BaseFieldValidator<Boolean> {
|
||||
public class EnumSpecifiedFieldValidator extends FieldValidator<Boolean> {
|
||||
|
||||
public EnumSpecifiedFieldValidator() {
|
||||
super("specified", Set.of(Descriptors.FieldDescriptor.Type.ENUM), MissingOptionalAction.FAIL, false);
|
||||
|
||||
@ -7,11 +7,10 @@ package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.Message;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public class ExactlySizeFieldValidator extends BaseFieldValidator<Set<Integer>> {
|
||||
public class ExactlySizeFieldValidator extends FieldValidator<Set<Integer>> {
|
||||
|
||||
public ExactlySizeFieldValidator() {
|
||||
super("exactlySize", Set.of(
|
||||
@ -50,11 +49,10 @@ public class ExactlySizeFieldValidator extends BaseFieldValidator<Set<Integer>>
|
||||
protected void validateRepeatedField(
|
||||
final Set<Integer> permittedSizes,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final Message msg) throws FieldValidationException {
|
||||
final int size = msg.getRepeatedFieldCount(fd);
|
||||
if (permittedSizes.contains(size)) {
|
||||
final List<?> repeated) throws FieldValidationException {
|
||||
if (permittedSizes.contains(repeated.size())) {
|
||||
return;
|
||||
}
|
||||
throw new FieldValidationException("list size is [%d] but expected to be one of %s".formatted(size, permittedSizes));
|
||||
throw new FieldValidationException("list size is [%d] but expected to be one of %s".formatted(repeated.size(), permittedSizes));
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,11 +5,134 @@
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.Message;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public interface FieldValidator {
|
||||
public abstract class FieldValidator<T> {
|
||||
|
||||
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, Message msg)
|
||||
throws FieldValidationException;
|
||||
private final String extensionName;
|
||||
|
||||
private final Set<Descriptors.FieldDescriptor.Type> supportedTypes;
|
||||
|
||||
private final MissingOptionalAction missingOptionalAction;
|
||||
|
||||
private final boolean applicableToRepeated;
|
||||
|
||||
public enum MissingOptionalAction {
|
||||
FAIL,
|
||||
SUCCEED,
|
||||
VALIDATE_DEFAULT_VALUE
|
||||
}
|
||||
|
||||
|
||||
protected FieldValidator(
|
||||
final String extensionName,
|
||||
final Set<Descriptors.FieldDescriptor.Type> supportedTypes,
|
||||
final MissingOptionalAction missingOptionalAction,
|
||||
final boolean applicableToRepeated) {
|
||||
this.extensionName = requireNonNull(extensionName);
|
||||
this.supportedTypes = requireNonNull(supportedTypes);
|
||||
this.missingOptionalAction = missingOptionalAction;
|
||||
this.applicableToRepeated = applicableToRepeated;
|
||||
}
|
||||
|
||||
public MissingOptionalAction getMissingOptionalAction() {
|
||||
return missingOptionalAction;
|
||||
}
|
||||
|
||||
public String getExtensionName() {
|
||||
return extensionName;
|
||||
}
|
||||
|
||||
|
||||
/// Validate a field
|
||||
///
|
||||
/// @param extensionValue The value of the option this field was annotated with
|
||||
/// @param fd The descriptor of the field to validate
|
||||
/// @param fieldValue The value of the field to validate
|
||||
/// @throws FieldValidationException if a field constraint was violated
|
||||
public final void validate(
|
||||
final Object extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final Object fieldValue) throws FieldValidationException {
|
||||
final T extensionValueTyped = resolveExtensionValue(extensionValue);
|
||||
|
||||
// for the `repeated` fields, checking if it's supported by the extension
|
||||
if (fd.isRepeated() && fieldValue instanceof List list) {
|
||||
if (!applicableToRepeated) {
|
||||
throw new IllegalArgumentException("can't apply extension %s to `repeated` field %s"
|
||||
.formatted(extensionName, fd.getFullName()));
|
||||
}
|
||||
validateRepeatedField(extensionValueTyped, fd, list);
|
||||
return;
|
||||
}
|
||||
final Descriptors.FieldDescriptor.Type type = fd.getType();
|
||||
if (!supportedTypes.contains(type)) {
|
||||
throw new IllegalArgumentException("can't apply extension %s to field %s of type %s".formatted(
|
||||
extensionName, fd.getFullName(), type));
|
||||
}
|
||||
switch (type) {
|
||||
case INT64, UINT64, INT32, FIXED64, FIXED32, UINT32, SFIXED32, SFIXED64, SINT32, SINT64 ->
|
||||
validateIntegerNumber(extensionValueTyped, ((Number) fieldValue).longValue(), type);
|
||||
case STRING -> validateStringValue(extensionValueTyped, (String) fieldValue);
|
||||
case BYTES -> validateBytesValue(extensionValueTyped, (ByteString) fieldValue);
|
||||
case ENUM -> validateEnumValue(extensionValueTyped, (Descriptors.EnumValueDescriptor) fieldValue);
|
||||
case MESSAGE -> validateMessageValue(extensionValueTyped, (Message) fieldValue);
|
||||
case FLOAT, DOUBLE, BOOL, GROUP -> {
|
||||
// at this moment, there are no validations specific to these types of fields
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
protected abstract T resolveExtensionValue(final Object extensionValue) throws FieldValidationException;
|
||||
|
||||
protected void validateRepeatedField(
|
||||
final T extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final List<?> repeated) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateRepeatedField` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateIntegerNumber(
|
||||
final T extensionValue,
|
||||
final long fieldValue, final Descriptors.FieldDescriptor.Type type) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateIntegerNumber` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateStringValue(
|
||||
final T extensionValue,
|
||||
final String fieldValue) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateStringValue` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateBytesValue(
|
||||
final T extensionValue,
|
||||
final ByteString fieldValue) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateBytesValue` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateEnumValue(
|
||||
final T extensionValue,
|
||||
final Descriptors.EnumValueDescriptor enumValueDescriptor) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateEnumValue` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected void validateMessageValue(
|
||||
final T extensionValue,
|
||||
final Message message) throws FieldValidationException {
|
||||
throw new UnsupportedOperationException("`validateMessageValue` method needs to be implemented");
|
||||
}
|
||||
|
||||
protected static boolean requireFlagExtension(final Object extensionValue) throws FieldValidationException {
|
||||
if (extensionValue instanceof Boolean flagIsOn && flagIsOn) {
|
||||
return true;
|
||||
}
|
||||
throw new UnsupportedOperationException("only value `true` is allowed");
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,11 +7,11 @@ package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.Message;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class NonEmptyFieldValidator extends BaseFieldValidator<Boolean> {
|
||||
public class NonEmptyFieldValidator extends FieldValidator<Boolean> {
|
||||
|
||||
public NonEmptyFieldValidator() {
|
||||
super("nonEmpty", Set.of(
|
||||
@ -49,8 +49,8 @@ public class NonEmptyFieldValidator extends BaseFieldValidator<Boolean> {
|
||||
protected void validateRepeatedField(
|
||||
final Boolean extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final Message msg) throws FieldValidationException {
|
||||
if (msg.getRepeatedFieldCount(fd) > 0) {
|
||||
final List<?> repeated) throws FieldValidationException {
|
||||
if (repeated.size() > 0) {
|
||||
return;
|
||||
}
|
||||
throw new FieldValidationException("repeated field is expected to be non-empty");
|
||||
|
||||
@ -9,7 +9,7 @@ import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.Message;
|
||||
import java.util.Set;
|
||||
|
||||
public class PresentFieldValidator extends BaseFieldValidator<Boolean> {
|
||||
public class PresentFieldValidator extends FieldValidator<Boolean> {
|
||||
|
||||
public PresentFieldValidator() {
|
||||
super("present",
|
||||
|
||||
@ -9,7 +9,7 @@ import com.google.protobuf.Descriptors;
|
||||
import java.util.Set;
|
||||
import org.signal.chat.require.ValueRangeConstraint;
|
||||
|
||||
public class RangeFieldValidator extends BaseFieldValidator<Range> {
|
||||
public class RangeFieldValidator extends FieldValidator<Range> {
|
||||
|
||||
private static final Set<Descriptors.FieldDescriptor.Type> UNSIGNED_TYPES = Set.of(
|
||||
Descriptors.FieldDescriptor.Type.FIXED32,
|
||||
|
||||
@ -12,7 +12,7 @@ import java.util.Set;
|
||||
import org.signal.chat.common.IdentityType;
|
||||
import org.signal.chat.common.ServiceIdentifier;
|
||||
|
||||
public class ServiceIdentifierIdentityTypeValidator extends BaseFieldValidator<IdentityType> {
|
||||
public class ServiceIdentifierIdentityTypeValidator extends FieldValidator<IdentityType> {
|
||||
|
||||
public ServiceIdentifierIdentityTypeValidator() {
|
||||
super("identityType", Set.of(Descriptors.FieldDescriptor.Type.MESSAGE), MissingOptionalAction.SUCCEED, false);
|
||||
|
||||
@ -7,11 +7,11 @@ package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.Message;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import org.signal.chat.require.SizeConstraint;
|
||||
|
||||
public class SizeFieldValidator extends BaseFieldValidator<Range> {
|
||||
public class SizeFieldValidator extends FieldValidator<Range> {
|
||||
|
||||
public SizeFieldValidator() {
|
||||
super("size", Set.of(
|
||||
@ -45,8 +45,8 @@ public class SizeFieldValidator extends BaseFieldValidator<Range> {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final Message msg) throws FieldValidationException {
|
||||
final int size = msg.getRepeatedFieldCount(fd);
|
||||
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final List<?> repeated) throws FieldValidationException {
|
||||
final int size = repeated.size();
|
||||
if (size < range.min() || size > range.max()) {
|
||||
throw new FieldValidationException("field value is [%d] but expected to be within the [%d, %d] range".formatted(
|
||||
size, range.min(), range.max()));
|
||||
|
||||
@ -105,9 +105,8 @@ message ClearRegistrationLockResponse {
|
||||
}
|
||||
|
||||
message ReserveUsernameHashRequest {
|
||||
// A prioritized list of username hashes to attempt to reserve. Each hash must
|
||||
// be exactly 32 bytes.
|
||||
repeated bytes username_hashes = 1 [(require.size) = {min: 1, max: 20}];
|
||||
// A prioritized list of username hashes to attempt to reserve.
|
||||
repeated bytes username_hashes = 1 [(require.size) = {min: 1, max: 20}, (require.each) = {exactlySize: 32}];
|
||||
}
|
||||
|
||||
message UsernameNotAvailable {}
|
||||
|
||||
@ -385,19 +385,19 @@ message MismatchedDevices {
|
||||
// A list of device IDs that are linked to the destination account, but were
|
||||
// not included in the collection of messages bound for the destination
|
||||
// account.
|
||||
repeated uint32 missing_devices = 2 [(require.range).max = 0x7f];
|
||||
repeated uint32 missing_devices = 2 [(require.each) = {range: {max: 0x7f}}];
|
||||
|
||||
// A list of device IDs that were included in the collection of messages bound
|
||||
// for the destination account, but are not currently linked to the
|
||||
// destination account.
|
||||
repeated uint32 extra_devices = 3 [(require.range).max = 0x7f];
|
||||
repeated uint32 extra_devices = 3 [(require.each) = {range: {max: 0x7f}}];
|
||||
|
||||
// A list of device IDs that present in the collection of messages bound for
|
||||
// the destination account and are linked to the destination account, but have
|
||||
// a different registration ID than the registration ID presented by the
|
||||
// sender (indicating that the destination device has likely been replaced by
|
||||
// another device).
|
||||
repeated uint32 stale_devices = 4 [(require.range).max = 0x7f];
|
||||
repeated uint32 stale_devices = 4 [(require.each) = {range: {max: 0x7f}}];
|
||||
}
|
||||
|
||||
message MultiRecipientMismatchedDevices {
|
||||
|
||||
@ -178,7 +178,40 @@ extend google.protobuf.FieldOptions {
|
||||
*/
|
||||
optional IdentityType identityType = 70009;
|
||||
|
||||
// next 70010
|
||||
/*
|
||||
* Applies element-wise constraints to the elements of a `repeated` field.
|
||||
* Top-level `require.*` annotations on a `repeated` field constrain the
|
||||
* collection itself (e.g. element count), `each` constrains the
|
||||
* individual elements.
|
||||
*
|
||||
* ```
|
||||
* import "org/signal/chat/require.proto";
|
||||
*
|
||||
* message Data {
|
||||
* // 1-20 username hashes, each exactly 32 bytes
|
||||
* repeated bytes username_hashes = 1 [
|
||||
* (require.size) = {min: 1, max: 20},
|
||||
* (require.each) = { exactlySize: 32 }
|
||||
* ];
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* Applicable only to `repeated` fields.
|
||||
*/
|
||||
optional ElementConstraint each = 70010;
|
||||
|
||||
// next 70011
|
||||
}
|
||||
|
||||
message ElementConstraint {
|
||||
optional bool nonEmpty = 1;
|
||||
optional SizeConstraint size = 2;
|
||||
repeated uint32 exactlySize = 3;
|
||||
optional bool e164 = 4;
|
||||
optional bool base64url = 5;
|
||||
optional ValueRangeConstraint range = 6;
|
||||
optional IdentityType identityType = 7;
|
||||
optional bool specified = 8;
|
||||
}
|
||||
|
||||
message SizeConstraint {
|
||||
|
||||
@ -445,6 +445,79 @@ public class ValidatingInterceptorTest {
|
||||
stub.validationsEndpoint(builderWithValidDefaults().build());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void eachExactlySize() {
|
||||
assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults().build()));
|
||||
|
||||
assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.addEachExactlyBytes(ByteString.copyFrom(new byte[4]))
|
||||
.addEachExactlyBytes(ByteString.copyFrom(new byte[4]))
|
||||
.build()));
|
||||
|
||||
// second element is the wrong length
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.addEachExactlyBytes(ByteString.copyFrom(new byte[4]))
|
||||
.addEachExactlyBytes(ByteString.copyFrom(new byte[3]))
|
||||
.build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void eachRange() {
|
||||
assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.addEachRange(1)
|
||||
.addEachRange(127)
|
||||
.build()));
|
||||
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.addEachRange(0)
|
||||
.build()));
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.addEachRange(128)
|
||||
.build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void eachNonEmpty() {
|
||||
assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.addEachNonEmpty("a")
|
||||
.build()));
|
||||
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.addEachNonEmpty("a")
|
||||
.addEachNonEmpty("")
|
||||
.build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void eachCombinedWithCollectionConstraint() {
|
||||
// empty list fails the collection-level `size {min:1}` constraint
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.clearEachCombined()
|
||||
.build()));
|
||||
|
||||
// too many elements (collection-level `size {max:3}`)
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.clearEachCombined()
|
||||
.addEachCombined(ByteString.copyFrom(new byte[4]))
|
||||
.addEachCombined(ByteString.copyFrom(new byte[4]))
|
||||
.addEachCombined(ByteString.copyFrom(new byte[4]))
|
||||
.addEachCombined(ByteString.copyFrom(new byte[4]))
|
||||
.build()));
|
||||
|
||||
// count is in range but an element is the wrong length
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.clearEachCombined()
|
||||
.addEachCombined(ByteString.copyFrom(new byte[3]))
|
||||
.build()));
|
||||
|
||||
// both constraints satisfied
|
||||
assertDoesNotThrow(() -> stub.validationsEndpoint(builderWithValidDefaults()
|
||||
.clearEachCombined()
|
||||
.addEachCombined(ByteString.copyFrom(new byte[4]))
|
||||
.addEachCombined(ByteString.copyFrom(new byte[4]))
|
||||
.build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFailedValidationOnNestedMessage() {
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () ->
|
||||
@ -528,7 +601,8 @@ public class ValidatingInterceptorTest {
|
||||
.setPniServiceIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(IdentityType.IDENTITY_TYPE_PNI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build());
|
||||
.build())
|
||||
.addEachCombined(ByteString.copyFrom(new byte[4]));
|
||||
}
|
||||
|
||||
private static void assertStatusException(final Status expected, final Executable serviceCall) {
|
||||
|
||||
@ -84,7 +84,16 @@ message ValidationsRequest {
|
||||
optional common.ServiceIdentifier aciServiceIdentifier = 34 [(require.identityType) = IDENTITY_TYPE_ACI];
|
||||
optional common.ServiceIdentifier pniServiceIdentifier = 35 [(require.identityType) = IDENTITY_TYPE_PNI];
|
||||
|
||||
// next 36
|
||||
repeated bytes eachExactlyBytes = 36 [(require.each) = { exactlySize: 4 }];
|
||||
repeated uint32 eachRange = 37 [(require.each) = { range: { min: 1, max: 127 } }];
|
||||
repeated string eachNonEmpty = 38 [(require.each) = { nonEmpty: true }];
|
||||
// collection-level size combined with per-element size
|
||||
repeated bytes eachCombined = 39 [
|
||||
(require.size) = { min: 1, max: 3 },
|
||||
(require.each) = { exactlySize: 4 }
|
||||
];
|
||||
|
||||
// next 40
|
||||
}
|
||||
|
||||
message NestedMessage {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user