Add omnibus H2 server and update to dropwizard 5.0.1

This commit is contained in:
Ravi Khadiwala 2026-05-18 11:27:17 -05:00
parent 0f7d5d7fa4
commit 59f704b6cc
70 changed files with 3021 additions and 840 deletions

14
pom.xml
View File

@ -49,7 +49,7 @@
<braintree.version>3.48.0</braintree.version>
<commons-csv.version>1.14.1</commons-csv.version>
<commons-io.version>2.21.0</commons-io.version>
<dropwizard.version>4.0.16</dropwizard.version>
<dropwizard.version>5.0.1</dropwizard.version>
<!-- Note: when updating FoundationDB, also include a copy of `libfdb_c.so` from the FoundationDB release at
src/main/jib/usr/lib/libfdb_c.so. We use x86_64 builds without AVX instructions enabled (i.e. FoundationDB versions
with even-numbered patch versions). Also when updating FoundationDB, make sure to update the version of FoundationDB
@ -70,13 +70,13 @@
<kotlin.version>2.3.20</kotlin.version>
<logback.version>1.5.32</logback.version>
<logback-access-common.version>2.0.12</logback-access-common.version>
<lettuce.version>6.8.2.RELEASE</lettuce.version>
<lettuce.version>7.5.1.RELEASE</lettuce.version>
<libphonenumber.version>9.0.21</libphonenumber.version>
<logstash.logback.version>8.1</logstash.logback.version>
<log4j-bom.version>2.25.4</log4j-bom.version>
<luajava.version>3.5.0</luajava.version>
<micrometer.version>1.16.4</micrometer.version>
<netty.version>4.1.127.Final</netty.version>
<netty.version>4.2.13.Final</netty.version>
<!-- Must be less than or equal to the value from Google libraries-bom which controls the protobuf runtime version.
See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
<protoc.version>4.33.2</protoc.version>
@ -253,12 +253,6 @@
<artifactId>commons-logging</artifactId>
<version>1.3.6</version>
</dependency>
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
<version>9.9.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.stripe</groupId>
<artifactId>stripe-java</artifactId>
@ -361,7 +355,7 @@
</dependency>
<dependency>
<groupId>org.wiremock</groupId>
<artifactId>wiremock</artifactId>
<artifactId>wiremock-jetty12</artifactId>
<version>3.13.1</version>
<scope>test</scope>
</dependency>

View File

@ -517,6 +517,7 @@ idlePrimaryDeviceReminder:
grpc:
port: 50051
websocketPort: 8080
asnTable:
s3Region: a-region

View File

@ -23,6 +23,7 @@
<opentelemetry-logback-appender-1.0.version>2.22.0-alpha</opentelemetry-logback-appender-1.0.version>
<storekit.version>4.0.0</storekit.version>
<webauthn4j.version>0.30.2.RELEASE</webauthn4j.version>
<jetty.http2-client.version>12.1.5</jetty.http2-client.version>
</properties>
<dependencies>
@ -137,6 +138,12 @@
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-jetty</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.http2</groupId>
<artifactId>jetty-http2-client-transport</artifactId>
<scope>test</scope>
<version>${jetty.http2-client.version}</version>
</dependency>
<dependency>
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-validation</artifactId>
@ -242,15 +249,15 @@
<dependency>
<groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>websocket-jetty-api</artifactId>
<artifactId>jetty-websocket-jetty-api</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-servlets</artifactId>
<groupId>org.eclipse.jetty.ee10</groupId>
<artifactId>jetty-ee10-servlets</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>websocket-jetty-client</artifactId>
<artifactId>jetty-websocket-jetty-client</artifactId>
<scope>test</scope>
</dependency>

View File

@ -31,21 +31,25 @@ import io.lettuce.core.metrics.MicrometerOptions;
import io.lettuce.core.resource.ClientResources;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.Mapping;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.ServletRegistration;
import java.io.ByteArrayInputStream;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.http.HttpClient;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
@ -60,9 +64,9 @@ import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.function.Function;
import java.util.stream.Stream;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.core.WebSocketExtensionRegistry;
import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ServerProperties;
import org.signal.i18n.HeaderControlledResourceBundleLookup;
import org.signal.libsignal.zkgroup.GenericServerSecretParams;
@ -135,10 +139,12 @@ import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.currency.FixerClient;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.filters.ExternalRequestFilter;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.RestDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.StripContentLengthOnConnectFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService;
@ -164,7 +170,10 @@ import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.ValidatingInterceptor;
import org.whispersystems.textsecuregcm.grpc.net.ManagedGrpcServer;
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.ManagedEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.OmnibusH2Server;
import org.whispersystems.textsecuregcm.grpc.net.OmnibusRouter;
import org.whispersystems.textsecuregcm.grpc.net.SniMapper;
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
@ -193,7 +202,7 @@ import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.metrics.CallQualitySurveyManager;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener;
import org.whispersystems.textsecuregcm.metrics.MetricsHttpChannelListener;
import org.whispersystems.textsecuregcm.metrics.MetricsHttpEventHandler;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener;
@ -314,6 +323,7 @@ import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import javax.annotation.Nullable;
public class WhisperServerService extends Application<WhisperServerConfiguration> {
@ -617,10 +627,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ScheduledExecutorService cloudflareTurnRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "cloudflareTurnRetry").threads(1).build();
ScheduledExecutorService messagePollExecutor = ScheduledExecutorServiceBuilder.of(environment, "messagePollExecutor").threads(1).build();
ScheduledExecutorService provisioningWebsocketTimeoutExecutor = ScheduledExecutorServiceBuilder.of(environment, "provisioningWebsocketTimeout").threads(1).build();
ScheduledExecutorService jmxDumper = ScheduledExecutorServiceBuilder.of(environment, "jmxDumper").threads(1).build();
final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup();
final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next())
final ManagedEventLoopGroup<NioEventLoopGroup> dnsResolutionEventLoopGroup = new ManagedEventLoopGroup<>(new NioEventLoopGroup());
final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.getEventLoopGroup().next())
.resolvedAddressTypes(ResolvedAddressTypes.IPV6_PREFERRED)
.completeOncePreferredResolved(false)
.channelType(NioDatagramChannel.class)
@ -1013,24 +1022,38 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
prohibitAuthenticationInterceptor))
.toList();
final ServerBuilder<?> serverBuilder =
NettyServerBuilder.forAddress(new InetSocketAddress(config.getGrpc().bindAddress(), config.getGrpc().port()));
final ManagedEventLoopGroup<DefaultEventLoopGroup> omnibusLocalEventLoopGroup = new ManagedEventLoopGroup<>(new DefaultEventLoopGroup());
final ManagedEventLoopGroup<NioEventLoopGroup> omnibusNioEventLoopGroup = new ManagedEventLoopGroup<>(new NioEventLoopGroup());
final LocalAddress grpcLocalAddress = new LocalAddress("grpc");
final ServerBuilder<?> serverBuilder = NettyServerBuilder
.forAddress(grpcLocalAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(omnibusLocalEventLoopGroup.getEventLoopGroup())
.workerEventLoopGroup(omnibusLocalEventLoopGroup.getEventLoopGroup());
authenticatedServices.forEach(serverBuilder::addService);
unauthenticatedServices.forEach(serverBuilder::addService);
final ManagedGrpcServer exposedGrpcServer = new ManagedGrpcServer(serverBuilder.build());
final ManagedGrpcServer localGrpcServer = new ManagedGrpcServer(serverBuilder.build());
environment.lifecycle().manage(exposedGrpcServer);
final SocketAddress websocketAddress =
new InetSocketAddress(config.getGrpc().websocketAddress(), config.getGrpc().websocketPort());
final OmnibusRouter omnibusRouter = new OmnibusRouter(List.of(
new OmnibusRouter.OmnibusRoute("/v1/websocket", websocketAddress),
new OmnibusRouter.OmnibusRoute("/v1/provisioning", websocketAddress)),
grpcLocalAddress);
@Nullable final Mapping<String, SslContext> sniMapping = config.getGrpc().h2c()
? null
: SniMapper.buildSniMapping(config.getTlsKeyStoreConfiguration().path(), config.getTlsKeyStoreConfiguration().password().value());
final OmnibusH2Server omnibusH2Server = new OmnibusH2Server(
sniMapping,
omnibusNioEventLoopGroup.getEventLoopGroup(),
omnibusLocalEventLoopGroup.getEventLoopGroup(),
new InetSocketAddress(config.getGrpc().bindAddress(), config.getGrpc().port()), omnibusRouter,
config.getGrpc().idleTimeout());
final List<Filter> filters = new ArrayList<>();
filters.add(remoteDeprecationFilter);
filters.add(new RemoteAddressFilter());
filters.add(new TimestampResponseFilter());
for (Filter filter : filters) {
environment.servlets()
.addFilter(filter.getClass().getSimpleName(), filter)
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
}
environment.lifecycle().manage(omnibusLocalEventLoopGroup);
environment.lifecycle().manage(omnibusNioEventLoopGroup);
environment.lifecycle().manage(localGrpcServer);
environment.lifecycle().manage(omnibusH2Server);
if (!config.getExternalRequestFilterConfiguration().paths().isEmpty()) {
environment.servlets().addFilter(ExternalRequestFilter.class.getSimpleName(),
@ -1048,9 +1071,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final String websocketServletPath = "/v1/websocket/";
final String provisioningWebsocketServletPath = "/v1/websocket/provisioning/";
final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(clientReleaseManager,
Set.of(websocketServletPath, provisioningWebsocketServletPath, "/health-check"));
metricsHttpChannelListener.configure(environment);
MetricsHttpEventHandler.configure(environment, Metrics.globalRegistry, clientReleaseManager, Set.of(websocketServletPath, provisioningWebsocketServletPath, "/health-check"));
final MessageMetrics messageMetrics = new MessageMetrics();
// BufferingInterceptor is needed on the base environment but not the WebSocketEnvironment,
@ -1174,36 +1195,36 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
provisioningEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (context, container) -> {
final WebSocketExtensionRegistry extensionRegistry = WebSocketServerComponents
.getWebSocketComponents(environment.getApplicationContext().getServletContext())
.getExtensionRegistry();
if (config.getWebSocketConfiguration().isDisablePerMessageDeflate()) {
extensionRegistry.unregister("permessage-deflate");
} else if (config.getWebSocketConfiguration().isDisableCrossMessageOutgoingCompression()) {
extensionRegistry.unregister("permessage-deflate");
extensionRegistry.register("permessage-deflate", NoContextTakeoverPerMessageDeflateExtension.class);
}
});
WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, AuthenticatedDevice.class, config.getWebSocketConfiguration(),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
webSocketEnvironment, AuthenticatedDevice.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
WebSocketResourceProviderFactory<AuthenticatedDevice> provisioningServlet = new WebSocketResourceProviderFactory<>(
provisioningEnvironment, AuthenticatedDevice.class, config.getWebSocketConfiguration(),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
provisioningEnvironment, AuthenticatedDevice.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(),
(servletContext, container) -> {
container.addMapping(websocketServletPath, webSocketServlet);
container.addMapping(provisioningWebsocketServletPath, provisioningServlet);
websocket.addMapping(websocketServletPath);
websocket.setAsyncSupported(true);
PriorityFilter.ensureFilter(servletContext, new StripContentLengthOnConnectFilter());
PriorityFilter.ensureFilter(servletContext, new TimestampResponseFilter());
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
PriorityFilter.ensureFilter(servletContext, remoteDeprecationFilter);
provisioning.addMapping(provisioningWebsocketServletPath);
provisioning.setAsyncSupported(true);
container.setMaxBinaryMessageSize(config.getWebSocketConfiguration().getMaxBinaryMessageSize());
container.setMaxTextMessageSize(config.getWebSocketConfiguration().getMaxTextMessageSize());
final WebSocketExtensionRegistry extensionRegistry = WebSocketServerComponents
.getWebSocketComponents(environment.getApplicationContext())
.getExtensionRegistry();
if (config.getWebSocketConfiguration().isDisablePerMessageDeflate()) {
extensionRegistry.unregister("permessage-deflate");
} else if (config.getWebSocketConfiguration().isDisableCrossMessageOutgoingCompression()) {
extensionRegistry.unregister("permessage-deflate");
extensionRegistry.register("permessage-deflate", NoContextTakeoverPerMessageDeflateExtension.class);
}
});
environment.admin().addTask(new SetRequestLoggingEnabledTask());
}
private void registerExceptionMappers(Environment environment,

View File

@ -10,10 +10,9 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;

View File

@ -5,11 +5,33 @@
package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.constraints.NotNull;
import java.time.Duration;
/// Configuration for the gRPC Server
///
/// @param bindAddress The host to bind the omnibus server to
/// @param port The port to bind the omnibus server to
/// @param websocketAddress The address of a listening websocket server for handling legacy requests
/// @param websocketPort The port of a listening websocket server for handling legacy requests
/// @param idleTimeout The duration after which an idle connection may be disconnected
/// @param h2c If true, listen for plaintext h2c with prior-knowledge
public record GrpcConfiguration(
@NotNull String bindAddress,
@NotNull Integer port,
@NotNull String websocketAddress,
@NotNull Integer websocketPort,
@NotNull Duration idleTimeout,
boolean h2c) {
public record GrpcConfiguration(@NotNull String bindAddress, @NotNull Integer port) {
public GrpcConfiguration {
if (bindAddress == null || bindAddress.isEmpty()) {
bindAddress = "localhost";
}
if (websocketAddress == null || websocketAddress.isEmpty()) {
websocketAddress = "localhost";
}
if (idleTimeout == null) {
idleTimeout = Duration.ofMinutes(5);
}
}
}

View File

@ -7,6 +7,9 @@ package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
import javax.annotation.Nullable;
public record TlsKeyStoreConfiguration(@NotNull SecretString password) {
public record TlsKeyStoreConfiguration(
@Nullable String path,
@NotNull SecretString password) {
}

View File

@ -0,0 +1,82 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.filters;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.ServletContext;
import java.util.EnumSet;
import java.util.Objects;
import org.eclipse.jetty.ee10.servlet.FilterHolder;
import org.eclipse.jetty.ee10.servlet.FilterMapping;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
import org.eclipse.jetty.ee10.servlet.ServletHandler;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.component.LifeCycle;
public class PriorityFilter {
private PriorityFilter() {}
private static FilterHolder getFilter(ServletContext servletContext, final Class<? extends Filter> filterClass) {
final ContextHandler contextHandler = Objects.requireNonNull(ServletContextHandler.getServletContextHandler(servletContext));
final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class);
return servletHandler.getFilter(filterClass.getName());
}
/**
* Ensure a filter is available on the provided ServletContext, a new filter will added if one does not already
* exist.
* <p>
* If a new filter is added, it will be added before all other filters.
* <p>
* Modeled after {@link org.eclipse.jetty.ee10.websocket.servlet.WebSocketUpgradeFilter#ensureFilter(ServletContext)},
* since its use of {@link org.eclipse.jetty.ee10.servlet.ServletHandler#prependFilter(FilterHolder)} is what makes
* this necessary.
*/
public static void ensureFilter(final ServletContext servletContext, final Filter filter) {
FilterHolder existingFilter = getFilter(servletContext, filter.getClass());
if (existingFilter != null) {
return;
}
final ContextHandler contextHandler = ServletContextHandler.getServletContextHandler(servletContext);
final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class);
final String pathSpec = "/*";
final FilterHolder holder = new FilterHolder(filter);
holder.setName(filter.getClass().getName());
holder.setAsyncSupported(true);
final FilterMapping mapping = new FilterMapping();
mapping.setFilterName(holder.getName());
mapping.setPathSpec(pathSpec);
mapping.setDispatcherTypes(EnumSet.of(DispatcherType.REQUEST));
// Add as the first filter in the list.
servletHandler.prependFilter(holder);
servletHandler.prependFilterMapping(mapping);
// If we create the filter we must also make sure it is removed if the context is stopped.
contextHandler.addEventListener(new LifeCycle.Listener()
{
@Override
public void lifeCycleStopping(LifeCycle event)
{
servletHandler.removeFilterHolder(holder);
servletHandler.removeFilterMapping(mapping);
contextHandler.removeEventListener(this);
}
@Override
public String toString()
{
return String.format("%sCleanupListener", filter.getClass().getSimpleName());
}
});
}
}

View File

@ -26,10 +26,6 @@ public class RemoteAddressFilter implements Filter {
public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress";
private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class);
public RemoteAddressFilter() {
}
@Override
public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
throws ServletException, IOException {

View File

@ -0,0 +1,71 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.filters;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.eclipse.jetty.ee10.servlet.ServletContextRequest;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpMethod;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.http.MetaData;
import org.eclipse.jetty.server.HttpStream;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.Callback;
/// Our current version of jetty (12.1.5) has a bug where it includes content-length:0 on
/// CONNECT websocket upgrade requests. Providing an HTTP/2 header frame with a
/// content-length that does not match the sum of the lengths of the data frames is technically
/// a malformed HTTP/2 stream and our netty-based reverse proxy implementation rejects it. This
/// filter strips out the superfluous content-length at stream-send time. It can be removed once
/// we update to a jetty version that fixes [jetty/jetty.project#15074](https://github.com/jetty/jetty.project/issues/15074)
public class StripContentLengthOnConnectFilter implements Filter {
@Override
public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
throws IOException, ServletException {
if (request instanceof HttpServletRequest hsr &&
HttpVersion.HTTP_2.is(hsr.getProtocol()) &&
HttpMethod.CONNECT.is(hsr.getMethod())) {
final Request coreRequest = ServletContextRequest.getServletContextRequest(hsr);
if (coreRequest != null) {
coreRequest.addHttpStreamWrapper(StripContentLengthStream::new);
}
}
chain.doFilter(request, response);
}
private static class StripContentLengthStream extends HttpStream.Wrapper {
StripContentLengthStream(final HttpStream wrapped) {
super(wrapped);
}
@Override
public void send(MetaData.Request request, MetaData.Response response, boolean last, ByteBuffer content,
Callback callback) {
if (response != null && response.getStatus() == 200 && response.getHttpFields()
.contains(HttpHeader.CONTENT_LENGTH)) {
final HttpFields fieldsWithoutContentLengthHeader =
HttpFields.build(response.getHttpFields()).remove(HttpHeader.CONTENT_LENGTH);
response = new MetaData.Response(
response.getStatus(),
response.getReason(),
response.getHttpVersion(),
fieldsWithoutContentLengthHeader,
-1,
response.getTrailersSupplier());
}
super.send(request, response, last, content, callback);
}
}
}

View File

@ -48,6 +48,8 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
final String acceptLanguageHeader = headers.get(ACCEPT_LANG_KEY);
final String xForwardedForHeader = headers.get(X_FORWARDED_FOR_KEY);
// This assumes that X-Forwarded-For has been set by a trusted intermediate proxy. For example, this may be set by
// OmnibusH2Server which itself sets X-Forwarded-For using a PPv2 header that comes from a trusted load-balancer.
final Optional<InetAddress> remoteAddress = getMostRecentProxy(xForwardedForHeader)
.flatMap(mostRecentProxy -> {
try {

View File

@ -0,0 +1,91 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.micrometer.core.instrument.Metrics;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http2.Http2StreamFrame;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/// Writes all inbound H2 frames to [this#peerStream], renumbering the inbound H2 stream-id for peer H2 stream
public class H2FrameProxyHandler extends ChannelInboundHandlerAdapter {
private static final Logger logger = LoggerFactory.getLogger(H2FrameProxyHandler.class);
private static final String WRITABILITY_CHANGED_COUNTER_NAME = MetricsUtil.name(H2FrameProxyHandler.class, "writabilityChanged");
private final Channel peerStream;
private final String proxyNameTag;
// If we fail to write to the peerStream, we want to close the inbound channel. Rather than allocate a new listener
// that captures the inbound ChannelHandlerContext on every message, we capture the ChannelHandlerContext in
// handlerAdded and use it on all forwarded writes. This would not work if we attached this handler to more than
// one channel, but we already have a designated peerStream so this handler is fundamentally single-channel.
private ChannelFutureListener closeInboundOnPeerFailure = null;
public H2FrameProxyHandler(final Channel peerStream, final String proxyNameTag) {
this.peerStream = peerStream;
this.proxyNameTag = proxyNameTag;
}
@Override
public void handlerAdded(final ChannelHandlerContext ctx) {
closeInboundOnPeerFailure = f -> {
if (!f.isSuccess()) {
ctx.close();
}
};
// When the peer stream we are forwarding to becomes unwritable/writable, stop/start reading from the source stream.
// This prevents us from reading from the source stream as fast as we can just to buffer requests for the peer
// stream.
peerStream.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelWritabilityChanged(final ChannelHandlerContext peerCtx) throws Exception {
Metrics.counter(WRITABILITY_CHANGED_COUNTER_NAME,
"isWritable", Boolean.toString(peerCtx.channel().isWritable()),
"proxy", proxyNameTag)
.increment();
ctx.channel().config().setAutoRead(peerStream.isWritable());
super.channelWritabilityChanged(peerCtx);
}
});
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
logger.trace("Received frame {}", msg);
if (!(msg instanceof Http2StreamFrame streamFrame)) {
logger.error("Received unexpected frame {}", msg);
ReferenceCountUtil.release(msg);
ctx.close();
return;
}
// Clear the stream-id on this frame, so netty will associate it with the peerStream's stream-id. The inbound
// frame has a stream-id associated with the inbound connection. This will not match the stream-id of the peer
// stream we are forwarding the frames to. If the stream-id on a frame is not set, netty handles sending the
// stream-id on the frame to the target stream's stream-id.
streamFrame.stream(null);
peerStream.writeAndFlush(streamFrame).addListener(closeInboundOnPeerFailure);
}
@Override
public void channelInactive(final ChannelHandlerContext ctx) {
peerStream.close();
}
@Override
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
logger.warn("Exception proxying frames", cause);
ctx.close();
}
}

View File

@ -0,0 +1,30 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.dropwizard.lifecycle.Managed;
import io.netty.channel.EventLoopGroup;
/**
* A wrapper for a Netty {@link EventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
* Dropwizard to manage the lifecycle of the event loop group.
*/
public class ManagedEventLoopGroup<T extends EventLoopGroup> implements Managed {
private final T eventLoopGroup;
public ManagedEventLoopGroup(final T eventLoopGroup) {
this.eventLoopGroup = eventLoopGroup;
}
@Override
public void stop() throws Exception {
this.eventLoopGroup.shutdownGracefully().await();
}
public T getEventLoopGroup() {
return eventLoopGroup;
}
}

View File

@ -1,16 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.dropwizard.lifecycle.Managed;
import io.netty.channel.nio.NioEventLoopGroup;
/**
* A wrapper for a Netty {@link NioEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
* Dropwizard to manage the lifecycle of the event loop group.
*/
public class ManagedNioEventLoopGroup extends NioEventLoopGroup implements Managed {
@Override
public void stop() throws Exception {
this.shutdownGracefully().await();
}
}

View File

@ -0,0 +1,38 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import java.util.concurrent.atomic.AtomicLong;
@ChannelHandler.Sharable
public class OmnibusConnectionCounterHandler extends ChannelInboundHandlerAdapter {
private final AtomicLong openConnections;
private final Counter acceptedConnectionsCounter =
Metrics.counter(MetricsUtil.name(OmnibusConnectionCounterHandler.class, "connectionsAccepted"));
public OmnibusConnectionCounterHandler() {
openConnections =
Metrics.gauge(MetricsUtil.name(OmnibusConnectionCounterHandler.class, "openConnections"), new AtomicLong());
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
acceptedConnectionsCounter.increment();
openConnections.incrementAndGet();
super.channelRegistered(ctx);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
openConnections.decrementAndGet();
super.channelInactive(ctx);
}
}

View File

@ -0,0 +1,55 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.micrometer.core.instrument.Metrics;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/// A handler that closes the channel on an exception and records errors in a counter. This should be placed at the tail
/// of pipelines to catch uncaught exceptions gracefully
@ChannelHandler.Sharable
public class OmnibusExceptionHandler extends ChannelInboundHandlerAdapter {
private static final Logger logger = LoggerFactory.getLogger(OmnibusExceptionHandler.class);
private static final String UNCAUGHT_EXCEPTION_COUNTER_NAME = MetricsUtil.name(OmnibusExceptionHandler.class,
"uncaughtException");
private final String channelName;
private final List<Class<? extends Exception>> expectedExceptions;
public OmnibusExceptionHandler(final String channelName, final List<Class<? extends Exception>> expectedExceptions) {
this.channelName = channelName;
this.expectedExceptions = expectedExceptions;
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
Metrics.counter(UNCAUGHT_EXCEPTION_COUNTER_NAME,
"channelName", channelName,
"exceptionClass", cause.getClass().getSimpleName())
.increment();
// There are 'expected' ways to get exceptions on a channel (e.g. client disconnects) so we only log them at debug.
if (expectedException(cause)) {
logger.debug("uncaught exception on channel {}", channelName, cause);
} else {
logger.warn("unexpected uncaught exception on channel {}", channelName, cause);
}
ctx.close();
}
private boolean expectedException(final Throwable exception) {
return expectedExceptions
.stream()
.anyMatch(expectedException -> expectedException.isInstance(exception));
}
}

View File

@ -0,0 +1,176 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Metrics;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufAllocatorMetricProvider;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.SimpleUserEventChannelHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.http2.Http2Exception;
import io.netty.handler.codec.http2.Http2FrameCodecBuilder;
import io.netty.handler.codec.http2.Http2MultiplexHandler;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.Http2StreamChannel;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.Mapping;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.SocketException;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/// An HTTP/2 server that proxies H2 streams to configurable backends via path-based routing
public class OmnibusH2Server implements Managed {
private static final Logger logger = LoggerFactory.getLogger(OmnibusH2Server.class);
private static final OmnibusConnectionCounterHandler CONNECT_COUNTER = new OmnibusConnectionCounterHandler();
private static final OmnibusExceptionHandler HANDSHAKE_EXCEPTION_HANDLER =
new OmnibusExceptionHandler("omnibus-handshake", List.of(SocketException.class, DecoderException.class, IOException.class));
private static final OmnibusExceptionHandler SESSION_EXCEPTION_HANDLER =
new OmnibusExceptionHandler("omnibus-session", List.of(Http2Exception.class));
private static final String IDLE_DISCONNECT_COUNTER_NAME = MetricsUtil.name(OmnibusH2Server.class, "idleDisconnect");
private final @Nullable Mapping<String, SslContext> sslContextBySni;
private final OmnibusRouter router;
private final Duration idleTimeout;
private final DefaultEventLoopGroup localEventLoopGroup;
private final NioEventLoopGroup nioEventLoopGroup;
private final SocketAddress bindAddress;
private Channel serverChannel;
/// Create an omnibus server
///
/// @param sslContextBySni If not null, a mapping between domain (SNI) and the appropriate SslContext to use for
/// that SNI. If null, the server will not include TLS (h2c with prior-knowledge)
/// @param nioEventLoopGroup Event loop to use for all NIO channel pipelines
/// @param localEventLoopGroup Event loop to use for all local channel pipelines
/// @param bindAddress The address the server should listen on
/// @param router How the server should select backends based on request paths
public OmnibusH2Server(
final @Nullable Mapping<String, SslContext> sslContextBySni,
final NioEventLoopGroup nioEventLoopGroup,
final DefaultEventLoopGroup localEventLoopGroup,
final SocketAddress bindAddress,
final OmnibusRouter router,
final Duration idleTimeout) {
this.sslContextBySni = sslContextBySni;
this.nioEventLoopGroup = nioEventLoopGroup;
this.localEventLoopGroup = localEventLoopGroup;
this.bindAddress = bindAddress;
this.router = router;
this.idleTimeout = idleTimeout;
}
@Override
public void start() throws Exception {
if (this.sslContextBySni == null) {
logger.warn("No SSL configuration provided for OmnibusH2Server, serving h2c");
}
if (ByteBufAllocator.DEFAULT instanceof ByteBufAllocatorMetricProvider alloc) {
Metrics.gauge(MetricsUtil.name(OmnibusH2Server.class, "nettyUsedDirectMemory"),
alloc,
allocator -> allocator.metric().usedDirectMemory());
Metrics.gauge(MetricsUtil.name(OmnibusH2Server.class, "nettyUsedHeapMemory"),
alloc,
allocator -> allocator.metric().usedHeapMemory());
}
final ServerBootstrap bootstrap = new ServerBootstrap()
.group(nioEventLoopGroup)
.channel(NioServerSocketChannel.class)
.childOption(ChannelOption.SO_KEEPALIVE, true)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel ch) {
ch.pipeline().addLast(new IdleStateHandler(0, 0, idleTimeout.toMillis(), TimeUnit.MILLISECONDS));
ch.pipeline().addLast(new SimpleUserEventChannelHandler<IdleStateEvent>() {
@Override
protected void eventReceived(final ChannelHandlerContext ctx, final IdleStateEvent evt) {
Metrics.counter(IDLE_DISCONNECT_COUNTER_NAME, "type", evt.state().name()).increment();
ctx.close();
}
});
ch.pipeline().addLast(CONNECT_COUNTER);
ch.pipeline().addLast(new ProxyProtocolHandler());
ch.pipeline().addLast(new ProxyMessageAttributeSetterHandler());
if (sslContextBySni == null) {
configureH2Pipeline(ch.pipeline());
} else {
ch.pipeline().addLast(new SniHandler(sslContextBySni));
ch.pipeline().addLast(new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_2) {
@Override
protected void configurePipeline(final ChannelHandlerContext ctx, final String protocol) {
if (!ApplicationProtocolNames.HTTP_2.equals(protocol)) {
// HTTP/2 should be enforced by our ALPN settings
logger.error("Unsupported protocol negotiated: {}, closing connection", protocol);
ctx.close();
return;
}
configureH2Pipeline(ctx.pipeline());
}
});
ch.pipeline().addLast(HANDSHAKE_EXCEPTION_HANDLER);
}
}
});
serverChannel = bootstrap.bind(bindAddress).sync().channel();
logger.info("Omnibus server listening on {}", getLocalAddress());
}
@VisibleForTesting
InetSocketAddress getLocalAddress() {
return (InetSocketAddress) serverChannel.localAddress();
}
@Override
public void stop() {
if (serverChannel != null) {
logger.info("Stopping omnibus server");
serverChannel.close().syncUninterruptibly();
logger.info("Omnibus server stopped");
}
}
private void configureH2Pipeline(final ChannelPipeline pipeline) {
// Advertise support for RFC-8441 extended connect
final Http2Settings settings = Http2Settings.defaultSettings().connectProtocolEnabled(true);
pipeline.addLast(Http2FrameCodecBuilder.forServer().initialSettings(settings).build());
pipeline.addLast(new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
@Override
protected void initChannel(final Http2StreamChannel ch) {
ch.pipeline().addLast(new OmnibusH2StreamHandler(nioEventLoopGroup, localEventLoopGroup, router));
}
}));
pipeline.addLast(SESSION_EXCEPTION_HANDLER);
}
}

View File

@ -0,0 +1,230 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.DefaultHttp2HeadersFrame;
import io.netty.handler.codec.http2.DefaultHttp2ResetFrame;
import io.netty.handler.codec.http2.Http2Error;
import io.netty.handler.codec.http2.Http2FrameCodecBuilder;
import io.netty.handler.codec.http2.Http2HeadersFrame;
import io.netty.handler.codec.http2.Http2MultiplexHandler;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.Http2StreamChannel;
import io.netty.handler.codec.http2.Http2StreamChannelBootstrap;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/// Handler added on each newly created [Http2StreamChannel] on an H2 connection. Inspects the [Http2HeadersFrame] and
/// determines which backend to forward the stream to, and then proxies frames to and from the backend.
///
/// When this handler receives an H2 header for a new stream-id=X on our parent H2 connection it will
/// - Receive the stream-id=X header and check the path to determine the correct backend
/// - Make a new H2 connection to the backend
/// - Forward the header with stream-id=Y to the backend
/// - Install a [H2FrameProxyHandler] on the backend stream pipeline that forwards the received stream-id=Y frames from
/// the backend back to the client on stream-id=X
/// - Install a [H2FrameProxyHandler] on the client stream pipeline forwards the received stream-id=X frames from the
/// client to the backend on stream-id=Y
public class OmnibusH2StreamHandler extends ChannelInboundHandlerAdapter {
private static final Logger logger = LoggerFactory.getLogger(OmnibusH2StreamHandler.class);
private static final OmnibusExceptionHandler BACKEND_CONNECTION_EXCEPTION_HANDLER =
new OmnibusExceptionHandler("backend-connection", List.of());
private static final String BACKEND_STREAM_COUNTER_NAME = name(OmnibusH2StreamHandler.class, "backendStream");
private static final String BACKEND_CONNECT_DURATION_NAME = name(OmnibusH2StreamHandler.class,
"backendConnectDuration");
private static final String BACKEND_TAG = "backend";
private final OmnibusRouter router;
private final DefaultEventLoopGroup localEventLoopGroup;
private final NioEventLoopGroup nioEventLoopGroup;
public OmnibusH2StreamHandler(
final NioEventLoopGroup nioEventLoopGroup,
final DefaultEventLoopGroup localEventLoopGroup,
final OmnibusRouter router) {
this.router = router;
this.localEventLoopGroup = localEventLoopGroup;
this.nioEventLoopGroup = nioEventLoopGroup;
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (!(msg instanceof Http2HeadersFrame headersFrame)) {
logger.warn("Expected initial HEADERS frame but got {}", msg.getClass().getSimpleName());
ReferenceCountUtil.release(msg);
ctx.close();
return;
}
// We don't expect headers frames to come with manually managed memory attached. Assert this in case this changes
// in the future, but for now we don't have to worry about freeing the headers frame
assert !(headersFrame instanceof ReferenceCounted);
// Disable reading from the client because we want to wait until we make the backend connection and install the
// forwarding handler before processing any more frames.
ctx.channel().config().setAutoRead(false);
// Select the target backend based on the path
final String path = Optional.ofNullable(headersFrame.headers().path()).map(CharSequence::toString).orElse("");
final SocketAddress target = router.match(path);
final String backendTag = target.toString();
Metrics.counter(BACKEND_STREAM_COUNTER_NAME, BACKEND_TAG, backendTag).increment();
// Set X-Forwarded-For from the PROXY protocol header if present, otherwise via the remote address
final InetAddress proxyRemoteAddress = ctx.channel().parent()
.attr(ProxyMessageAttributeSetterHandler.PROXY_REMOTE_ADDRESS)
.get();
headersFrame.headers().set("x-forwarded-for", proxyRemoteAddress != null
? proxyRemoteAddress.getHostAddress()
: ((InetSocketAddress) ctx.channel().remoteAddress()).getHostString());
// Make a new H2 connection to the target backend
final Timer.Sample connectSample = Timer.start();
new Bootstrap()
.group(selectEventLoop(ctx, target))
.channel(target instanceof LocalAddress ? LocalChannel.class : NioSocketChannel.class)
.handler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel ch) {
ch.pipeline()
.addLast(Http2FrameCodecBuilder.forClient().initialSettings(Http2Settings.defaultSettings()).build());
// Http2MultiplexHandler takes handler that is added to new inbound streams. A client Http2MultiplexHandler
// like we're defining here should never receive an inbound H2 stream so we can just pass a noop handler
ch.pipeline().addLast(new Http2MultiplexHandler(new NoopInboundStreamHandler()));
ch.pipeline().addLast(BACKEND_CONNECTION_EXCEPTION_HANDLER);
}
})
.connect(target)
.addListener((ChannelFuture connectFuture) -> {
connectSample.stop(Timer.builder(BACKEND_CONNECT_DURATION_NAME)
.tag(BACKEND_TAG, backendTag)
.tag("outcome", connectFuture.isSuccess() ? "success" : "failure")
.register(Metrics.globalRegistry));
if (!connectFuture.isSuccess()) {
// Close the client stream with a 502: Bad Gateway if the backend wasn't available
logger.warn("Failed to connect to backend {}", target, connectFuture.cause());
ctx.channel()
.writeAndFlush(new DefaultHttp2HeadersFrame(
new DefaultHttp2Headers().status("502"), true))
.addListener(ChannelFutureListener.CLOSE);
return;
}
// Connected, open a new H2 stream to the backend so we can proxy the client's frames
logger.trace("Opening a HTTP/2 stream to the backend {}", target);
final Channel backendConnection = connectFuture.channel();
createBackendProxyStream(ctx, backendConnection, headersFrame);
});
}
/// Create a proxy stream on the provided `backendConnection` that forwards H2 frames to/from the client H2 stream.
///
/// @param clientStreamCtx The context for a client H2 stream that targets the backend
/// @param backendConnection An established H2 connection [Channel], on which a new h2 stream will be opened
/// @param headersFrame The first `headersFrame` from the client h2 stream that should be forwarded to the new
/// backend stream
private void createBackendProxyStream(
final ChannelHandlerContext clientStreamCtx,
final Channel backendConnection,
final Http2HeadersFrame headersFrame) {
new Http2StreamChannelBootstrap(backendConnection)
// Forwards response frames from the backend back to the client stream
.handler(new H2FrameProxyHandler(clientStreamCtx.channel(), "responseStream"))
.open()
.addListener((io.netty.util.concurrent.Future<Http2StreamChannel> streamFuture) -> {
if (!streamFuture.isSuccess()) {
logger.warn("Failed to open backend stream", streamFuture.cause());
clientStreamCtx.channel()
.writeAndFlush(new DefaultHttp2ResetFrame(Http2Error.INTERNAL_ERROR))
.addListener(ChannelFutureListener.CLOSE);
backendConnection.close();
return;
}
final Http2StreamChannel backendStream = streamFuture.getNow();
// Close the entire H2 connection whenever the stream we just opened closes. We only plan on using
// a single stream on this connection.
backendStream.closeFuture().addListener(_ -> backendConnection.close());
// We're going to modify the inbound H2 stream channel, which runs on a different eventloop than the
// outbound channel we've made to the backend. We have to submit our updates back to the inbound
// channel's event loop for thread safety
clientStreamCtx.channel().eventLoop().execute(() -> {
if (!clientStreamCtx.channel().isActive()) {
// The client disconnected already and the client pipeline is already torn down.
backendConnection.close();
return;
}
// Install proxy on client stream, remove this handler, then fire the buffered headers through the proxy
clientStreamCtx.pipeline().replace(
OmnibusH2StreamHandler.this,
"backend-to-client-proxy",
new H2FrameProxyHandler(backendStream, "requestStream"));
clientStreamCtx.channel().pipeline().fireChannelRead(headersFrame);
// Resume inbound reads, which should now be forwarded
clientStreamCtx.channel().config().setAutoRead(true);
});
});
}
private EventLoopGroup selectEventLoop(final ChannelHandlerContext inboundCtx, SocketAddress target) {
final boolean localInbound = inboundCtx.channel() instanceof LocalChannel;
final boolean localTarget = target instanceof LocalAddress;
// If the inbound eventloop matches the target type, we can just reuse the inbound's event loop
if (localInbound == localTarget) {
return inboundCtx.channel().eventLoop();
}
return localTarget ? this.localEventLoopGroup : this.nioEventLoopGroup;
}
private static class NoopInboundStreamHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRegistered(final ChannelHandlerContext ctx) {
logger.error("Inbound stream handler was registered when no inbound streams expected");
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
logger.error("Received unexpected message: {} on inbound stream from backend", msg);
super.channelRead(ctx, msg);
}
}
}

View File

@ -0,0 +1,30 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import java.net.SocketAddress;
import java.util.List;
public class OmnibusRouter {
public record OmnibusRoute(String prefix, SocketAddress backend) {}
private final List<OmnibusRoute> prefixRoutes;
private final SocketAddress defaultBackend;
public OmnibusRouter(final List<OmnibusRoute> prefixRoutes, final SocketAddress defaultBackend) {
this.prefixRoutes = prefixRoutes;
this.defaultBackend = defaultBackend;
}
SocketAddress match(final String path) {
for (final OmnibusRoute route : prefixRoutes) {
if (path.startsWith(route.prefix)) {
return route.backend;
}
}
return defaultBackend;
}
}

View File

@ -0,0 +1,42 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.util.AttributeKey;
import java.net.InetAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/// Reads the decoded [HAProxyMessage], stores the source address as a channel attribute, and removes itself.
class ProxyMessageAttributeSetterHandler extends ChannelInboundHandlerAdapter {
private static final Logger logger = LoggerFactory.getLogger(ProxyMessageAttributeSetterHandler.class);
/// Attribute for the remote address extracted from a proxy protocol header
static final AttributeKey<InetAddress> PROXY_REMOTE_ADDRESS = AttributeKey.newInstance("proxyRemoteAddress");
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (!(msg instanceof HAProxyMessage proxyMessage)) {
ctx.pipeline().remove(this);
ctx.fireChannelRead(msg);
return;
}
try {
final String sourceAddress = proxyMessage.sourceAddress();
if (sourceAddress != null) {
ctx.channel().attr(PROXY_REMOTE_ADDRESS).set(InetAddress.getByName(sourceAddress));
} else {
logger.warn("PROXY protocol message has no source address");
}
} finally {
proxyMessage.release();
ctx.pipeline().remove(this);
}
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Metrics;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.ProtocolDetectionResult;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import java.util.List;
class ProxyProtocolHandler extends ByteToMessageDecoder {
private static final String PROXY_PROTOCOL_DETECTED_NAME =
name(ProxyProtocolHandler.class, "proxyProtocol");
@Override
protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List<Object> out) {
// This does not advance the read index, so the bytes we accumulate via ByteToMessageDecoder are always forwarded
// once we get enough (either to an HAProxyMessageDecoder, or just to the rest of the pipeline)
final ProtocolDetectionResult<HAProxyProtocolVersion> detected = HAProxyMessageDecoder.detectProtocol(in);
switch (detected.state()) {
case NEEDS_MORE_DATA:
break;
case DETECTED:
// There is a valid proxy-protocol header. Replace ourselves with the actual decoder (which will forward our
// accumulated bytes to the decoder via handlerRemoved)
Metrics.counter(PROXY_PROTOCOL_DETECTED_NAME, "detected", "true").increment();
ctx.pipeline().replace(this, "haproxy-decoder", new HAProxyMessageDecoder());
break;
case INVALID:
// No header, we can just forward any bytes we've accumulated.
Metrics.counter(PROXY_PROTOCOL_DETECTED_NAME, "detected", "false").increment();
ctx.pipeline().remove(this);
break;
}
}
}

View File

@ -0,0 +1,164 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.Mapping;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.Key;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateParsingException;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class SniMapper {
private static final Logger logger = LoggerFactory.getLogger(SniMapper.class);
private SniMapper() {
}
/// Build a [Mapping] from a [KeyStore] that maps from domain (via SAN) to [io.netty.handler.ssl.SslContext] that
/// can be used with an [io.netty.handler.ssl.SniHandler]. The provided keystore may contain multiple certificates
/// for a single domain, all matching certificates will be included in the corresponding SslContext. The domain for
/// a certificate is only determined by the SAN and all certificates must have a SAN. The returned [Mapping] returns
/// an arbitrary set of certificates if none of the certificates in the keystore match a requested domain, as
/// permitted by RFC-6066.
///
/// @param keyStorePath The path to the [KeyStore]
/// @param keyStorePassword The password for the keyStore
/// @return A [Mapping] that maps domains to the corresponding [SslContext] containing the certificates for that
/// domain
public static Mapping<String, SslContext> buildSniMapping(final String keyStorePath, final String keyStorePassword)
throws IOException {
try (final FileInputStream fis = new FileInputStream(keyStorePath)) {
return buildSniMapping(fis, keyStorePassword);
}
}
@VisibleForTesting
static Mapping<String, SslContext> buildSniMapping(final InputStream keyStore, final String keyStorePassword)
throws IOException {
try {
final Map<String, KeyStore> domainKeyStores = partitionByDomain(keyStore, keyStorePassword.toCharArray());
final Map<String, SslContext> sslContextsByDomain = new HashMap<>();
for (final Map.Entry<String, KeyStore> entry : domainKeyStores.entrySet()) {
final KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(entry.getValue(), keyStorePassword.toCharArray());
sslContextsByDomain.put(entry.getKey(), buildSslContext(kmf));
}
// Netty expects the SNI mapping to always return an SslContext. Per RFC-6066 it's valid to continue the handshake
// on an SNI mismatch, it's the client's responsibility to check the returned certificate's SNI. We sort first so
// our choice of certificate is deterministic.
final SslContext defaultSslContext = sslContextsByDomain.entrySet().stream()
.min(Map.Entry.comparingByKey())
.orElseThrow(() -> new IllegalArgumentException("Key store contained no certificates"))
.getValue();
logger.info("Loaded TLS contexts for domains: {}", sslContextsByDomain.keySet());
return hostname -> sslContextsByDomain.getOrDefault(hostname, defaultSslContext);
} catch (NoSuchAlgorithmException | KeyStoreException | CertificateException | UnrecoverableKeyException e) {
throw new IOException("Failed to load keystore", e);
}
}
private static SslContext buildSslContext(final KeyManagerFactory kmf) throws SSLException {
return SslContextBuilder.forServer(kmf)
.applicationProtocolConfig(new ApplicationProtocolConfig(
ApplicationProtocolConfig.Protocol.ALPN,
ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
ApplicationProtocolNames.HTTP_2))
.protocols("TLSv1.3")
.build();
}
private static Map<String, KeyStore> partitionByDomain(final InputStream keystoreStream, final char[] keystorePassword)
throws KeyStoreException, CertificateException, UnrecoverableKeyException, IOException, NoSuchAlgorithmException {
final KeyStore keyStore = KeyStore.getInstance("PKCS12");
keyStore.load(keystoreStream, keystorePassword);
// Group key entries by the domain(s) in each certificate's SANs
final Map<String, KeyStore> domainKeyStores = new HashMap<>();
for (final String alias : Collections.list(keyStore.aliases())) {
if (!keyStore.isKeyEntry(alias)) {
continue;
}
final Certificate[] chain = keyStore.getCertificateChain(alias);
if (chain == null || chain.length == 0) {
continue;
}
final X509Certificate leaf = (X509Certificate) chain[0];
final Key key = keyStore.getKey(alias, keystorePassword);
for (final String domain : getDnsNames(leaf)) {
domainKeyStores
.computeIfAbsent(domain, _ -> newEmptyKeyStore())
.setKeyEntry(alias, key, keystorePassword, chain);
}
}
if (domainKeyStores.isEmpty()) {
throw new IOException("Keystore contains no usable key entries with DNS names");
}
return domainKeyStores;
}
private static KeyStore newEmptyKeyStore() {
try {
final KeyStore ks = KeyStore.getInstance("PKCS12");
ks.load(null, null);
return ks;
} catch (KeyStoreException | IOException | NoSuchAlgorithmException | CertificateException e) {
// All Java runtime implementations are required to support PKCS12, and we aren't loading anything from disk
// so an exception here is impossible.
throw new AssertionError("Failed to create empty keystore", e);
}
}
/// Extract all DNS-type SAN names on this certificate
private static List<String> getDnsNames(final X509Certificate cert) throws CertificateParsingException, IOException {
final Collection<List<?>> sans = cert.getSubjectAlternativeNames();
if (sans == null) {
throw new IOException("Certificate did not have SAN extension");
}
final List<String> dnsSans = sans.stream()
// GeneralName type 2 = dNSName. See getSubjectAlternativeNames
.filter(san -> (int) san.getFirst() == 2)
.map(san -> (String) san.get(1))
.map(s -> s.toLowerCase(Locale.ROOT))
.toList();
if (dnsSans.isEmpty()) {
throw new IOException("Certificate did not have a DNS SAN entry");
}
return dnsSans;
}
}

View File

@ -40,8 +40,6 @@ public class JettyHttpConfigurationCustomizer implements Container.Listener, Lif
httpConfiguration.setNotifyRemoteAsyncErrors(false);
}
}
c.addBean(new JettyConnectionMetrics(Metrics.globalRegistry));
}
}

View File

@ -14,7 +14,7 @@ import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
/**
* Delegates request events to a listener that captures and reports request-level metrics.
*
* @see MetricsHttpChannelListener
* @see MetricsHttpEventHandler
* @see MetricsRequestEventListener
*/
public class MetricsApplicationEventListener implements ApplicationEventListener {
@ -23,7 +23,7 @@ public class MetricsApplicationEventListener implements ApplicationEventListener
public MetricsApplicationEventListener(final TrafficSource trafficSource, final ClientReleaseManager clientReleaseManager) {
if (trafficSource == TrafficSource.HTTP) {
throw new IllegalArgumentException("Use " + MetricsHttpChannelListener.class.getName() + " for HTTP traffic");
throw new IllegalArgumentException("Use " + MetricsHttpEventHandler.class.getName() + " for HTTP traffic");
}
this.metricsRequestEventListener = new MetricsRequestEventListener(trafficSource, clientReleaseManager);
}

View File

@ -1,226 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerResponseContext;
import jakarta.ws.rs.container.ContainerResponseFilter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.component.Container;
import org.eclipse.jetty.util.component.LifeCycle;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
/**
* Gathers and reports HTTP request metrics at the Jetty container level, which sits above Jersey. In order to get
* templated Jersey request paths, it implements {@link jakarta.ws.rs.container.ContainerResponseFilter}, in order to give
* itself access to the template. It is limited to {@link TrafficSource#HTTP} requests.
* <p>
* It implements {@link LifeCycle.Listener} without overriding methods, so that it can be an event listener that
* Dropwizard will attach to the container&mdash;the {@link Container.Listener} implementation is where it attaches
* itself to any {@link Connector}s.
*
* @see MetricsRequestEventListener
*/
public class MetricsHttpChannelListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener,
ContainerResponseFilter {
private static final Set<String> EXPECTED_HTTP_METHODS =
Set.of("GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH");
private static final Logger logger = LoggerFactory.getLogger(MetricsHttpChannelListener.class);
private record RequestInfo(String path, String method, int statusCode, @Nullable String userAgent) {
}
private final ClientReleaseManager clientReleaseManager;
private final Set<String> servletPaths;
// Use the same counter namespace as MetricsRequestEventListener for continuity
public static final String REQUEST_COUNTER_NAME = MetricsRequestEventListener.REQUEST_COUNTER_NAME;
public static final String REQUESTS_BY_VERSION_COUNTER_NAME = MetricsRequestEventListener.REQUESTS_BY_VERSION_COUNTER_NAME;
@VisibleForTesting
static final String RESPONSE_BYTES_COUNTER_NAME = MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME;
@VisibleForTesting
static final String REQUEST_BYTES_COUNTER_NAME = MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME;
@VisibleForTesting
static final String URI_INFO_PROPERTY_NAME = MetricsHttpChannelListener.class.getName() + ".uriInfo";
@VisibleForTesting
static final String PATH_TAG = "path";
@VisibleForTesting
static final String METHOD_TAG = "method";
@VisibleForTesting
static final String STATUS_CODE_TAG = "status";
@VisibleForTesting
static final String TRAFFIC_SOURCE_TAG = "trafficSource";
private final MeterRegistry meterRegistry;
public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager, final Set<String> servletPaths) {
this(Metrics.globalRegistry, clientReleaseManager, servletPaths);
}
@VisibleForTesting
MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager,
final Set<String> servletPaths) {
this.meterRegistry = meterRegistry;
this.clientReleaseManager = clientReleaseManager;
this.servletPaths = servletPaths;
}
public void configure(final Environment environment) {
// register as ContainerResponseFilter
environment.jersey().register(this);
// hook into lifecycle events, to react to the Connector being added
environment.lifecycle().addEventListener(this);
}
@Override
public void onRequestFailure(final Request request, final Throwable failure) {
if (logger.isDebugEnabled()) {
final RequestInfo requestInfo = getRequestInfo(request);
logger.debug("Request failure: {} {} ({}) [{}] ",
requestInfo.method(),
requestInfo.path(),
requestInfo.userAgent(),
requestInfo.statusCode(), failure);
}
}
@Override
public void onResponseFailure(Request request, Throwable failure) {
if (failure instanceof org.eclipse.jetty.io.EofException) {
// the client disconnected early
return;
}
final RequestInfo requestInfo = getRequestInfo(request);
logger.warn("Response failure: {} {} ({}) [{}] ",
requestInfo.method(),
requestInfo.path(),
requestInfo.userAgent(),
requestInfo.statusCode(), failure);
}
@Override
public void onComplete(final Request request) {
final RequestInfo requestInfo = getRequestInfo(request);
@Nullable final UserAgent userAgent;
{
UserAgent parsedUserAgent;
try {
parsedUserAgent = UserAgentUtil.parseUserAgentString(requestInfo.userAgent());
} catch (final UnrecognizedUserAgentException e) {
parsedUserAgent = null;
}
userAgent = parsedUserAgent;
}
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
final List<Tag> tags = new ArrayList<>(5);
tags.add(Tag.of(PATH_TAG, requestInfo.path()));
tags.add(Tag.of(METHOD_TAG, requestInfo.method()));
tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(requestInfo.statusCode())));
tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()));
tags.add(platformTag);
final Optional<Tag> maybeClientVersionTag =
UserAgentTagUtil.getClientVersionTag(userAgent, clientReleaseManager);
maybeClientVersionTag.ifPresent(tags::add);
meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment();
meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags).increment(request.getResponse().getContentCount());
meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags).increment(request.getContentRead());
maybeClientVersionTag.ifPresent(clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME,
Tags.of(clientVersionTag, platformTag))
.increment());
}
@Override
public void beanAdded(final Container parent, final Object child) {
if (child instanceof Connector connector) {
connector.addBean(this);
}
}
@Override
public void beanRemoved(final Container parent, final Object child) {
}
@Override
public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext)
throws IOException {
requestContext.setProperty(URI_INFO_PROPERTY_NAME, requestContext.getUriInfo());
}
private RequestInfo getRequestInfo(Request request) {
final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME))
.map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr))
.orElseGet(() ->
Optional.ofNullable(request.getPathInfo())
.filter(servletPaths::contains)
.orElse("unknown")
);
// Response cannot be null, but its status might not always reflect an actual response status, since it gets
// initialized to 200
final int status = request.getResponse().getStatus();
@Nullable final String userAgent = request.getHeader(HttpHeaders.USER_AGENT);
return new RequestInfo(path, normalizeMethod(request.getMethod()), status, userAgent);
}
static String normalizeMethod(@Nullable final String method) {
if (StringUtils.isBlank(method)) {
return "unknown";
}
return EXPECTED_HTTP_METHODS.contains(method.toUpperCase(Locale.ROOT)) ? method : "unknown";
}
}

View File

@ -0,0 +1,280 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import jakarta.validation.constraints.NotNull;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerResponseContext;
import jakarta.ws.rs.container.ContainerResponseFilter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.io.Content;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.handler.EventsHandler;
import org.eclipse.jetty.util.component.LifeCycle;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
/**
* Gathers and reports HTTP request metrics at the Jetty container level, which sits above Jersey. In order to get
* templated Jersey request paths, it adds a {@link jakarta.ws.rs.container.ContainerResponseFilter}, in order to give
* itself access to the template. It is limited to {@link TrafficSource#HTTP} requests.
*
* @see MetricsRequestEventListener
*/
public class MetricsHttpEventHandler extends EventsHandler {
private static final Logger logger = LoggerFactory.getLogger(MetricsHttpEventHandler.class);
private static final Set<String> EXPECTED_HTTP_METHODS =
Set.of("GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH");
private final ClientReleaseManager clientReleaseManager;
private final Set<String> servletPaths;
// Use the same counter namespace as MetricsRequestEventListener for continuity
public static final String REQUEST_COUNTER_NAME = MetricsRequestEventListener.REQUEST_COUNTER_NAME;
public static final String REQUESTS_BY_VERSION_COUNTER_NAME = MetricsRequestEventListener.REQUESTS_BY_VERSION_COUNTER_NAME;
@VisibleForTesting
static final String RESPONSE_BYTES_COUNTER_NAME = MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME;
@VisibleForTesting
static final String REQUEST_BYTES_COUNTER_NAME = MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME;
@VisibleForTesting
static final String REQUEST_INFO_PROPERTY_NAME = MetricsHttpEventHandler.class.getName() + ".requestInfo";
@VisibleForTesting
static final String PATH_TAG = "path";
@VisibleForTesting
static final String METHOD_TAG = "method";
@VisibleForTesting
static final String STATUS_CODE_TAG = "status";
@VisibleForTesting
static final String TRAFFIC_SOURCE_TAG = "trafficSource";
private final MeterRegistry meterRegistry;
@VisibleForTesting
MetricsHttpEventHandler(
final Handler handler,
final MeterRegistry meterRegistry,
final ClientReleaseManager clientReleaseManager,
final Set<String> servletPaths) {
super(handler);
this.meterRegistry = meterRegistry;
this.clientReleaseManager = clientReleaseManager;
this.servletPaths = servletPaths;
}
/**
* Configure a {@link MetricsHttpEventHandler}
*
* @param environment A dropwizard {@link org.eclipse.jetty.util.component.Environment}
* @param meterRegistry The meter registry to register metrics with
* @param clientReleaseManager A {@link ClientReleaseManager} that determines what tags to include with metrics
* @param servletPaths An allow-list of paths to include in metric tags for requests that are handled by above
* Jersey
*/
public static void configure(final Environment environment, final MeterRegistry meterRegistry,
final ClientReleaseManager clientReleaseManager, final Set<String> servletPaths) {
// register a filter that will set the initial request info
environment.jersey().register(new SetInfoRequestFilter());
// hook into lifecycle events, to react to the Connector being added
environment.lifecycle().addEventListener(new LifeCycle.Listener() {
@Override
public void lifeCycleStarting(LifeCycle event) {
if (event instanceof Server server) {
server.setHandler(
new MetricsHttpEventHandler(server.getHandler(), meterRegistry, clientReleaseManager, servletPaths));
}
}
});
}
private void onResponseFailure(Request request, int status, Throwable failure) {
if (failure instanceof org.eclipse.jetty.io.EofException) {
// the client disconnected early
return;
}
final RequestInfo requestInfo = getRequestInfo(request);
logger.warn("Response failure: {} {} ({}) [{}] ",
requestInfo.method,
requestInfo.path,
requestInfo.userAgent,
status,
failure);
}
@Override
public void onComplete(Request request, int status, HttpFields headers, Throwable failure) {
super.onComplete(request, status, headers, failure);
if (failure != null) {
onResponseFailure(request, status, failure);
}
final RequestInfo requestInfo = getRequestInfo(request);
@Nullable final UserAgent userAgent;
{
UserAgent parsedUserAgent;
try {
parsedUserAgent = UserAgentUtil.parseUserAgentString(requestInfo.userAgent);
} catch (final UnrecognizedUserAgentException e) {
parsedUserAgent = null;
}
userAgent = parsedUserAgent;
}
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
final List<Tag> tags = new ArrayList<>(5);
tags.add(Tag.of(PATH_TAG, requestInfo.path));
tags.add(Tag.of(METHOD_TAG, requestInfo.method));
tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(status)));
tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()));
tags.add(platformTag);
final Optional<Tag> maybeClientVersionTag =
UserAgentTagUtil.getClientVersionTag(userAgent, clientReleaseManager);
maybeClientVersionTag.ifPresent(tags::add);
meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment();
meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags).increment(requestInfo.responseBytes);
meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags).increment(requestInfo.requestBytes);
maybeClientVersionTag.ifPresent(clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME,
Tags.of(clientVersionTag, platformTag))
.increment());
}
@Override
protected void onRequestRead(final Request request, final Content.Chunk chunk) {
super.onRequestRead(request, chunk);
if (chunk != null) {
getRequestInfo(request).requestBytes += chunk.remaining();
}
}
@Override
protected void onResponseWrite(final Request request, final boolean last, final ByteBuffer content) {
super.onResponseWrite(request, last, content);
if (content != null) {
getRequestInfo(request).responseBytes += content.remaining();
}
}
private RequestInfo getRequestInfo(Request request) {
Object obj = request.getAttribute(REQUEST_INFO_PROPERTY_NAME);
if (obj != null && obj instanceof RequestInfo requestInfo) {
return requestInfo;
}
// Our ContainerResponseFilter has not run yet. It should eventually run, and will override the path we set here.
// It may not run if this is a websocket upgrade request, a request handled by jetty directly, or a higher priority
// filter aborted the request by throwing an exception, in which case we'll use this path. To avoid giving every
// incorrect path a unique tag we check against a configured list of paths that we know would skip the filter.
final RequestInfo newInfo = new RequestInfo(
Optional.ofNullable(request.getHttpURI().getPath()).filter(servletPaths::contains).orElse("unknown"),
normalizeMethod(request.getMethod()),
request.getHeaders().get(HttpHeaders.USER_AGENT));
request.setAttribute(REQUEST_INFO_PROPERTY_NAME, newInfo);
return newInfo;
}
@VisibleForTesting
static class RequestInfo {
private String path;
private final String method;
private final @Nullable String userAgent;
private long requestBytes;
private long responseBytes;
RequestInfo(@NotNull String path, @NotNull String method, @Nullable String userAgent) {
this.path = path;
this.method = method;
this.userAgent = userAgent;
this.requestBytes = 0;
this.responseBytes = 0;
}
@Override
public boolean equals(final Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
RequestInfo that = (RequestInfo) o;
return requestBytes == that.requestBytes && responseBytes == that.responseBytes && Objects.equals(path, that.path)
&& Objects.equals(method, that.method) && Objects.equals(userAgent, that.userAgent);
}
@Override
public int hashCode() {
return Objects.hash(path, method, userAgent, requestBytes, responseBytes);
}
}
@VisibleForTesting
static class SetInfoRequestFilter implements ContainerResponseFilter {
@Override
public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext) {
// Construct the templated URI path. If no matching path is found, this will be ""
final String path = UriInfoUtil.getPathTemplate((ExtendedUriInfo) requestContext.getUriInfo());
final Object obj = requestContext.getProperty(REQUEST_INFO_PROPERTY_NAME);
if (obj != null && obj instanceof RequestInfo requestInfo) {
requestInfo.path = path;
} else {
requestContext.setProperty(REQUEST_INFO_PROPERTY_NAME,
new RequestInfo(path, requestContext.getMethod(), requestContext.getHeaderString(HttpHeaders.USER_AGENT)));
}
}
}
static String normalizeMethod(@Nullable final String method) {
if (StringUtils.isBlank(method)) {
return "unknown";
}
return EXPECTED_HTTP_METHODS.contains(method.toUpperCase(Locale.ROOT)) ? method : "unknown";
}
}

View File

@ -31,7 +31,7 @@ import org.whispersystems.websocket.WebSocketResourceProvider;
/**
* Gathers and reports request-level metrics for WebSocket traffic only.
* For HTTP traffic, use {@link MetricsHttpChannelListener}.
* For HTTP traffic, use {@link MetricsHttpEventHandler}.
*/
public class MetricsRequestEventListener implements RequestEventListener {
@ -69,7 +69,7 @@ public class MetricsRequestEventListener implements RequestEventListener {
this(trafficSource, Metrics.globalRegistry, clientReleaseManager);
if (trafficSource == TrafficSource.HTTP) {
logger.warn("Use {} for HTTP traffic", MetricsHttpChannelListener.class.getName());
logger.warn("Use {} for HTTP traffic", MetricsHttpEventHandler.class.getName());
}
}
@ -107,7 +107,7 @@ public class MetricsRequestEventListener implements RequestEventListener {
final List<Tag> tags = new ArrayList<>();
tags.add(Tag.of(PATH_TAG, UriInfoUtil.getPathTemplate(event.getUriInfo())));
tags.add(Tag.of(METHOD_TAG, MetricsHttpChannelListener.normalizeMethod(event.getContainerRequest().getMethod())));
tags.add(Tag.of(METHOD_TAG, MetricsHttpEventHandler.normalizeMethod(event.getContainerRequest().getMethod())));
tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(Optional
.ofNullable(event.getContainerResponse())
.map(ContainerResponse::getStatus)

View File

@ -21,7 +21,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import org.eclipse.jetty.util.resource.Resource;
import org.eclipse.jetty.util.resource.PathResourceFactory;
import org.eclipse.jetty.util.security.CertificateUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -37,7 +37,7 @@ public class TlsCertificateExpirationUtil {
final KeyStore keyStore;
try {
keyStore = CertificateUtils.getKeyStore(Resource.newResource(keyStorePath), keyStoreType, keyStoreProvider,
keyStore = CertificateUtils.getKeyStore(new PathResourceFactory().newResource(keyStorePath), keyStoreType, keyStoreProvider,
keyStorePassword);
} catch (Exception e) {

View File

@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.redis;
import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.MaintNotificationsConfig;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisURI;
@ -57,20 +58,10 @@ public class FaultTolerantRedisClient {
this.name = name;
// Lettuce will issue a CLIENT SETINFO command unconditionally if these fields are set (and they are by default),
// which can generate a bunch of spurious warnings in versions of Redis before 7.2.0.
//
// See:
//
// - https://github.com/redis/lettuce/pull/2823
// - https://github.com/redis/lettuce/issues/2817
redisUri.setClientName(null);
redisUri.setLibraryName(null);
redisUri.setLibraryVersion(null);
this.redisClient = RedisClient.create(clientResourcesBuilder.build(), redisUri);
final ClientOptions.Builder clientOptionsBuilder = ClientOptions.builder()
.disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS)
.maintNotificationsConfig(MaintNotificationsConfig.disabled())
// for asynchronous commands
.timeoutOptions(TimeoutOptions.builder()
.fixedTimeout(commandTimeout)

View File

@ -9,6 +9,7 @@ import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.MaintNotificationsConfig;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisURI;
import io.lettuce.core.TimeoutOptions;
@ -69,28 +70,15 @@ public class FaultTolerantRedisClusterClient {
this.name = name;
// Lettuce will issue a CLIENT SETINFO command unconditionally if these fields are set (and they are by default),
// which can generate a bunch of spurious warnings in versions of Redis before 7.2.0.
//
// See:
//
// - https://github.com/redis/lettuce/pull/2823
// - https://github.com/redis/lettuce/issues/2817
redisUris.forEach(redisUri -> {
redisUri.setClientName(null);
redisUri.setLibraryName(null);
redisUri.setLibraryVersion(null);
});
final LettuceShardCircuitBreaker lettuceShardCircuitBreaker =
new LettuceShardCircuitBreaker(name, circuitBreakerConfigurationName);
this.clusterClient = RedisClusterClient.create(
clientResourcesBuilder.nettyCustomizer(lettuceShardCircuitBreaker).
build(),
clientResourcesBuilder.nettyCustomizer(lettuceShardCircuitBreaker)
.build(),
redisUris);
final ClusterClientOptions.Builder clusterClientOptionsBuilder = ClusterClientOptions.builder()
final ClusterClientOptions.Builder clusterClientOptionsBuilder = (ClusterClientOptions.Builder) ClusterClientOptions.builder()
.disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS)
.validateClusterNodeMembership(false)
.topologyRefreshOptions(ClusterTopologyRefreshOptions.builder()
@ -100,7 +88,8 @@ public class FaultTolerantRedisClusterClient {
.timeoutOptions(TimeoutOptions.builder()
.fixedTimeout(commandTimeout)
.build())
.publishOnScheduler(true);
.publishOnScheduler(true)
.maintNotificationsConfig(MaintNotificationsConfig.disabled());
NettyUtil.setSocketTimeoutsIfApplicable(clusterClientOptionsBuilder);

View File

@ -1,33 +0,0 @@
package org.whispersystems.textsecuregcm.util;
import io.dropwizard.lifecycle.Managed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.management.MBeanServer;
import java.lang.management.ManagementFactory;
import java.util.concurrent.ScheduledExecutorService;
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
public class JmxDumper implements Managed {
private static final Logger log = LoggerFactory.getLogger(JmxDumper.class);
private final ScheduledExecutorService executor;
public JmxDumper(final ScheduledExecutorService executor) {
this.executor = executor;
}
@Override
public void start() throws Exception {
// executor.schedule()
}
private void dump() {
MBeanServer mbs = ManagementFactory.getPlatformMBeanServer();
}
}

View File

@ -8,14 +8,14 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentialsFromAuthHeader;
import com.google.common.net.HttpHeaders;
import javax.annotation.Nullable;
import io.dropwizard.auth.basic.BasicCredentials;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import java.util.Optional;
import javax.annotation.Nullable;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import java.util.Optional;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedDevice> {
@ -27,7 +27,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
}
@Override
public Optional<AuthenticatedDevice> authenticate(final UpgradeRequest request)
public Optional<AuthenticatedDevice> authenticate(final JettyServerUpgradeRequest request)
throws InvalidCredentialsException {
@Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);

View File

@ -15,7 +15,6 @@ import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@ -35,7 +34,6 @@ public class BufferingInterceptorIntegrationTest {
environment.jersey().register(testController);
environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10));
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
}
}

View File

@ -18,28 +18,29 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.EnumSet;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@ -52,6 +53,7 @@ import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class ProvisioningTimeoutIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
@ -77,9 +79,9 @@ public class ProvisioningTimeoutIntegrationTest {
CompletableFuture<String> provisioningAddressFuture = new CompletableFuture<>();
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE
&& webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) {
MessageProtos.ProvisioningAddress provisioningAddress =
@ -92,7 +94,7 @@ public class ProvisioningTimeoutIntegrationTest {
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
super.onWebSocketBinary(payload, offset, length);
super.onWebSocketBinary(payload, callback);
}
}
@ -106,21 +108,17 @@ public class ProvisioningTimeoutIntegrationTest {
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.setConnectListener(
new ProvisioningConnectListener(mock(ProvisioningManager.class), () -> mock(AsnInfoProvider.class), mock(ClientReleaseManager.class), scheduler, Duration.ofSeconds(5)));
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet = environment.servlets()
.addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> {
container.addMapping("/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
}
}

View File

@ -9,8 +9,6 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
@ -20,19 +18,20 @@ import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.Optional;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@ -41,6 +40,7 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class WebsocketResourceProviderIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
@ -72,9 +72,6 @@ public class WebsocketResourceProviderIntegrationTest {
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class)));
@ -85,15 +82,13 @@ public class WebsocketResourceProviderIntegrationTest {
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> {
container.addMapping("/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}

View File

@ -13,35 +13,47 @@ import io.dropwizard.testing.ConfigOverride;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.util.Resources;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.MetadataUtils;
import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.ServerSocket;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.http2.client.HTTP2Client;
import org.eclipse.jetty.http2.client.transport.HttpClientTransportOverHTTP2;
import org.eclipse.jetty.io.ClientConnector;
import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.OpenTelemetryConfiguration;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.account.AccountsAnonymousGrpc;
import org.signal.chat.account.CheckAccountExistenceRequest;
import org.signal.chat.account.CheckAccountExistenceResponse;
import org.signal.chat.common.IdentityType;
import org.signal.chat.common.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.NoopAwsSdkMetricPublisher;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
@ -58,13 +70,16 @@ class WhisperServerServiceTest {
System.setProperty("secrets.bundle.filename",
Resources.getResource("config/test-secrets-bundle.yml").getPath());
}
private static final int OMNIBUS_PORT = findAvailablePort();
private static final WebSocketClient webSocketClient = new WebSocketClient();
private static WebSocketClient webSocketClient;
private static WebSocketClient h2WebSocketClient;
private static final DropwizardAppExtension<WhisperServerConfiguration> EXTENSION = new DropwizardAppExtension<>(
WhisperServerService.class, Resources.getResource("config/test.yml").getPath(),
// Tables will be created by the local DynamoDbExtension
ConfigOverride.config("dynamoDbClient.initTables", "false"));
ConfigOverride.config("dynamoDbClient.initTables", "false"),
ConfigOverride.config("grpc.port", String.valueOf(OMNIBUS_PORT)));
@RegisterExtension
public static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.values());
@ -76,7 +91,13 @@ class WhisperServerServiceTest {
@BeforeAll
static void setUp() throws Exception {
final ClientConnector clientConnector = new ClientConnector();
final HTTP2Client http2Client = new HTTP2Client(clientConnector);
final HttpClient httpClient = new HttpClient(new HttpClientTransportOverHTTP2(http2Client));
h2WebSocketClient = new WebSocketClient(httpClient);
webSocketClient = new WebSocketClient();
webSocketClient.start();
h2WebSocketClient.start();
}
@Test
@ -100,8 +121,9 @@ class WhisperServerServiceTest {
assertEquals(200, healthCheck.getStatus());
}
@Test
void websocket() throws Exception {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void websocket(final boolean useH2) throws Exception {
// test unauthenticated websocket
final long start = System.currentTimeMillis();
@ -117,7 +139,7 @@ class WhisperServerServiceTest {
});
// Session is Closeable, but we intentionally keep it open so that we can confirm the container Lifecycle behavior
final Session session = webSocketClient.connect(testWebsocketListener,
final Session session = (useH2 ? h2WebSocketClient : webSocketClient).connect(testWebsocketListener,
URI.create(String.format("ws://localhost:%d/v1/websocket/", EXTENSION.getLocalPort())))
.join();
final long sessionTimestamp = Long.parseLong(session.getUpgradeResponse().getHeader(HeaderUtils.TIMESTAMP_HEADER));
@ -185,6 +207,51 @@ class WhisperServerServiceTest {
.build());
}
@Test
void omnibusWebsocket() throws Exception {
final HTTP2Client http2Client = new HTTP2Client(new ClientConnector());
final WebSocketClient h2WebSocketClient =
new WebSocketClient(new HttpClient(new HttpClientTransportOverHTTP2(http2Client)));
h2WebSocketClient.start();
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
final Session session = h2WebSocketClient.connect(testWebsocketListener,
URI.create(String.format("ws://localhost:%d/v1/websocket/", OMNIBUS_PORT)))
.join();
final WebSocketResponseMessage keepAlive = testWebsocketListener.doGet("/v1/keepalive").join();
assertEquals(200, keepAlive.getStatus());
final WebSocketResponseMessage whoami = testWebsocketListener.doGet("/v1/accounts/whoami").join();
assertEquals(401, whoami.getStatus());
session.close();
h2WebSocketClient.stop();
}
@Test
void omnibusGrpc() throws Exception {
final ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", OMNIBUS_PORT)
.usePlaintext()
.build();
final Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("X-Forwarded-For", Metadata.ASCII_STRING_MARSHALLER), "127.0.0.1");
final AccountsAnonymousGrpc.AccountsAnonymousBlockingStub stub = AccountsAnonymousGrpc
.newBlockingStub(channel)
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
final CheckAccountExistenceResponse response = stub.checkAccountExistence(
CheckAccountExistenceRequest.newBuilder()
.setServiceIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
.build())
.build());
assertFalse(response.getAccountExists());
channel.shutdownNow();
channel.awaitTermination(1, TimeUnit.SECONDS);
}
private static DynamoDbClient getDynamoDbClient() {
final AwsCredentialsProvider awsCredentialsProvider = EXTENSION.getConfiguration().getAwsCredentialsConfiguration()
.build();
@ -193,4 +260,12 @@ class WhisperServerServiceTest {
.buildSyncClient(awsCredentialsProvider, new NoopAwsSdkMetricPublisher());
}
private static int findAvailablePort() {
try (ServerSocket socket = new ServerSocket(0)) {
return socket.getLocalPort();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}

View File

@ -15,8 +15,8 @@ import java.util.List;
import java.util.Optional;
import java.util.UUID;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;

View File

@ -210,8 +210,7 @@ class PhoneVerificationTokenManagerTest {
@ParameterizedTest
@MethodSource
void verifyRecoveryPasswordManagerException(final Throwable recoveryPasswordManagerException)
throws ExecutionException, InterruptedException, TimeoutException {
void verifyRecoveryPasswordManagerException(final Throwable recoveryPasswordManagerException) {
final ContainerRequestContext containerRequestContext = mock(ContainerRequestContext.class);
final byte[] recoveryPassword = TestRandomUtil.nextBytes(16);
@ -219,11 +218,8 @@ class PhoneVerificationTokenManagerTest {
when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(containerRequestContext, PHONE_NUMBER))
.thenReturn(true);
@SuppressWarnings("unchecked") final CompletableFuture<Boolean> mockFuture = mock(CompletableFuture.class);
when(mockFuture.get(anyLong(), any())).thenThrow(recoveryPasswordManagerException);
when(registrationRecoveryPasswordsManager.verify(PHONE_NUMBER_IDENTIFIER, recoveryPassword))
.thenReturn(mockFuture);
.thenReturn(CompletableFuture.failedFuture(recoveryPasswordManagerException));
assertThrows(ServerErrorException.class, () -> phoneVerificationTokenManager.verify(containerRequestContext,
PHONE_NUMBER,

View File

@ -13,20 +13,17 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import java.io.IOException;
import java.net.InetAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@ -35,14 +32,15 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.security.auth.Subject;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.util.HostPort;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
@ -55,6 +53,7 @@ import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFa
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class RemoteAddressFilterIntegrationTest {
private static final String WEBSOCKET_PREFIX = "/websocket";
@ -131,7 +130,7 @@ class RemoteAddressFilterIntegrationTest {
}
}
private static class ClientEndpoint implements WebSocketListener {
public static class ClientEndpoint implements Session.Listener.AutoDemanding {
private final String requestPath;
private final CompletableFuture<byte[]> responseFuture;
@ -145,22 +144,19 @@ class RemoteAddressFilterIntegrationTest {
}
@Override
public void onWebSocketConnect(final Session session) {
public void onWebSocketOpen(final Session session) {
final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath,
List.of("Accept: application/json"),
Optional.empty()).toByteArray();
try {
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP);
}
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
assert 200 == webSocketMessage.getResponseMessage().getStatus();
@ -206,10 +202,6 @@ class RemoteAddressFilterIntegrationTest {
public void run(final Configuration configuration,
final Environment environment) throws Exception {
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH,
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
environment.jersey().register(new TestRemoteAddressController());
// WebSocket set up
@ -220,15 +212,14 @@ class RemoteAddressFilterIntegrationTest {
webSocketEnvironment.jersey().register(new TestWebSocketController());
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
webSocketEnvironment, TestPrincipal.class,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet)
.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> {
container.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH, webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
}
}

View File

@ -0,0 +1,136 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.ConfigOverride;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.ServletContext;
import java.net.InetSocketAddress;
import java.util.EnumSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.ee10.servlet.FilterHolder;
import org.eclipse.jetty.ee10.servlet.FilterMapping;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
import org.eclipse.jetty.ee10.servlet.ServletHandler;
import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketCreator;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.http.HostPortHttpField;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpScheme;
import org.eclipse.jetty.http.MetaData;
import org.eclipse.jetty.http2.api.Session;
import org.eclipse.jetty.http2.api.Stream;
import org.eclipse.jetty.http2.client.HTTP2Client;
import org.eclipse.jetty.http2.frames.HeadersFrame;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.Jetty;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@ExtendWith(DropwizardExtensionsSupport.class)
class StripContentLengthOnConnectFilterTest {
private static final String WEBSOCKET_WITH_FILTER_PATH = "/websocket/filtered";
private static final String WEBSOCKET_WITHOUT_FILTER_PATH = "/websocket/unfiltered";
private static final DropwizardAppExtension<Configuration> EXTENSION =
new DropwizardAppExtension<>(TestApplicationWithFilter.class, null,
ConfigOverride.config("server.applicationConnectors[0].type", "h2c"),
ConfigOverride.config("server.applicationConnectors[0].port", "0"));
@Test
void contentLengthIsStrippedOnUpgrade() throws Exception {
final HttpFields fields = sendUpgradeRequest(WEBSOCKET_WITH_FILTER_PATH);
assertNull(fields.getField("content-length"));
}
@Test
void contentLengthIncorrectlyIncludedOnUpgrade() throws Exception {
final HttpFields fields = sendUpgradeRequest(WEBSOCKET_WITHOUT_FILTER_PATH);
assertNotNull(fields.getField("content-length"), """
If this fails, our jetty version no longer includes errant content-lengths on H2 connect responses.
StripContentLengthOnConnectFilter can now be removed.
""");
}
@Test
void versionCheck() {
assertEquals("12.1.5", Jetty.VERSION, "This class can be removed with https://github.com/jetty/jetty.project/issues/15074, likely 12.1.10");
}
private static HttpFields sendUpgradeRequest(final String path) throws Exception {
final int port = EXTENSION.getLocalPort();
try (final HTTP2Client client = new HTTP2Client()) {
client.start();
final Session session =
client.connect(new InetSocketAddress("localhost", port), new Session.Listener() {}).join();
final HttpFields requestFields = HttpFields.build().put("sec-websocket-version", "13");
final HostPortHttpField hostPort = new HostPortHttpField("localhost:" + port);
final MetaData.ConnectRequest connect =
new MetaData.ConnectRequest(HttpScheme.HTTP, hostPort, path, requestFields, "websocket");
final HeadersFrame headersFrame = new HeadersFrame(connect, null, false);
final CompletableFuture<MetaData.Response> responseFuture = new CompletableFuture<>();
Stream.Listener streamListener = new Stream.Listener() {
@Override
public void onHeaders(Stream stream, HeadersFrame frame) {
if (frame.getMetaData().isResponse()) {
responseFuture.complete((MetaData.Response) frame.getMetaData());
}
}
};
session.newStream(headersFrame, streamListener).get(5, TimeUnit.SECONDS);
final MetaData.Response response = responseFuture.get(5, TimeUnit.SECONDS);
return response.getHttpFields();
}
}
private static class NoopWebSocket implements org.eclipse.jetty.websocket.api.Session.Listener.AutoDemanding {}
public static class TestApplicationWithFilter extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final JettyWebSocketCreator jwsc = (_, _) -> new NoopWebSocket();
JettyWebSocketServletContainerInitializer.configure(
environment.getApplicationContext(),
(servletContext, container) -> {
container.addMapping(WEBSOCKET_WITH_FILTER_PATH, jwsc);
container.addMapping(WEBSOCKET_WITHOUT_FILTER_PATH, jwsc);
ensureFilter(servletContext, WEBSOCKET_WITH_FILTER_PATH, StripContentLengthOnConnectFilter.class);
});
}
}
private static void ensureFilter(
final ServletContext servletContext,
final String pathSpec,
final Class<? extends Filter> filterClass) {
final ContextHandler contextHandler = ServletContextHandler.getServletContextHandler(servletContext);
final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class);
final FilterHolder holder = new FilterHolder(filterClass);
final FilterMapping mapping = new FilterMapping();
mapping.setFilterName(holder.getName());
mapping.setPathSpec(pathSpec);
mapping.setDispatcherTypes(EnumSet.of(DispatcherType.REQUEST));
servletHandler.prependFilter(holder);
servletHandler.prependFilterMapping(mapping);
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX License Identifier: AGPL 3.0 only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.util.ResourceLeakDetector;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
public abstract class AbstractLeakDetectionTest {
private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
@BeforeAll
static void setLeakDetectionLevel() {
originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel();
ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);
}
@AfterAll
static void restoreLeakDetectionLevel() {
ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel);
}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.embedded.EmbeddedChannel;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Test;
class H2FrameProxyHandlerTest extends AbstractLeakDetectionTest {
@Test
void proxyWritabilityChanged() {
final EmbeddedChannel target = new EmbeddedChannel();
final EmbeddedChannel source = new EmbeddedChannel(new H2FrameProxyHandler(target, "test"));
// Set a tiny watermark to guarantee an unflushed write sets to unwritable, and then buffer some data
final byte[] bufferedData = "8 bytes!".getBytes(StandardCharsets.UTF_8);
target.config().setWriteBufferWaterMark(new WriteBufferWaterMark(4, 8));
target.write(bufferedData);
assertNull(target.readOutbound(), "nothing should be written without a flush");
assertFalse(target.isWritable(), "target should be unwritable because we've buffered more than the high watermark");
assertFalse(source.config().isAutoRead(), "source should not read because the target is unwritable");
target.flush();
assertTrue(target.isWritable(), "after a flush, the target should be writable");
assertTrue(source.config().isAutoRead(), "after the target becomes writable, autoRead should be enabled");
}
}

View File

@ -0,0 +1,600 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.grpc.netty.shaded.io.netty.channel.ChannelOption;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleUserEventChannelHandler;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioChannelOption;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.codec.http2.DefaultHttp2DataFrame;
import io.netty.handler.codec.http2.DefaultHttp2GoAwayFrame;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.DefaultHttp2HeadersFrame;
import io.netty.handler.codec.http2.DefaultHttp2ResetFrame;
import io.netty.handler.codec.http2.Http2DataFrame;
import io.netty.handler.codec.http2.Http2Error;
import io.netty.handler.codec.http2.Http2FrameCodecBuilder;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2HeadersFrame;
import io.netty.handler.codec.http2.Http2MultiplexHandler;
import io.netty.handler.codec.http2.Http2ResetFrame;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.Http2StreamChannel;
import io.netty.handler.codec.http2.Http2StreamChannelBootstrap;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.ReferenceCountUtil;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
private static final String KEYSTORE_PASSWORD = "password";
// Paths that start with PREFIX should go to the prefix backend, everything else to default.
private static final String PREFIX_BACKEND_IDENTITY = "prefix-backend";
private static final String PREFIX = "/v1/prefix";
private static final String DEFAULT_BACKEND_IDENTITY = "default-backend";
private final NioEventLoopGroup nioEventLoopGroup = new NioEventLoopGroup();
private final DefaultEventLoopGroup localEventLoopGroup = new DefaultEventLoopGroup();
private Channel defaultBackend;
private Channel prefixBackend;
private CompletableFuture<Channel> backendConnection;
private OmnibusH2Server server;
@BeforeEach
void setUp() throws Exception {
// Start two H2C backend servers that echo a response with their identity
defaultBackend = startH2CServer(true, DEFAULT_BACKEND_IDENTITY);
prefixBackend = startH2CServer(false, PREFIX_BACKEND_IDENTITY);
backendConnection = new CompletableFuture<>();
// self-signed TLS context for the frontend loaded from test keyStore
final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12");
server = new OmnibusH2Server(
SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD),
nioEventLoopGroup,
localEventLoopGroup,
new InetSocketAddress("127.0.0.1", 0),
new OmnibusRouter(
List.of(new OmnibusRouter.OmnibusRoute(PREFIX, prefixBackend.localAddress())),
defaultBackend.localAddress()),
Duration.ofMinutes(1));
server.start();
}
@AfterEach
void tearDown() throws Exception {
server.stop();
defaultBackend.close().sync();
prefixBackend.close().sync();
localEventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync();
nioEventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync();
}
@Test
void defaultBackend() {
final String response = sendRequestThroughOmnibus("/a/different/path");
assertEquals(DEFAULT_BACKEND_IDENTITY, response);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void forwardedForHeader(final boolean usePpv2) {
final String expectedSource = usePpv2 ? "127.0.0.123" : "127.0.0.1";
final HAProxyMessage proxyMessage = usePpv2
? new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, expectedSource, "127.0.0.2", 1234, 5678)
: null;
final Channel h2Connection = connectToOmnibus(null, proxyMessage);
final String xForwardedFor = sendRequestThroughOmnibus(h2Connection, "/forwarded-for");
assertEquals(expectedSource, xForwardedFor);
}
@ParameterizedTest
@ValueSource(strings = {"/v1/prefix", "/v1/prefix/", "/v1/prefix/other"})
void prefixBackend(final String path) {
final String response = sendRequestThroughOmnibus(path);
assertEquals(PREFIX_BACKEND_IDENTITY, response);
}
@Test
void multipleStreamsOnSameConnection() {
final Channel h2Connection = connectToOmnibus();
final int numStreams = 10;
// Create concurrent streams to both backends on the same connection simultaneously
@SuppressWarnings("rawtypes")
final CompletableFuture[] futures = IntStream.range(0, numStreams)
.mapToObj(i -> CompletableFuture.supplyAsync(() ->
sendRequestThroughOmnibus(
h2Connection,
i % 2 == 0 ? PREFIX : "/v1/other")))
.toArray(CompletableFuture[]::new);
// Ensure we get the response from the correct backend for each stream
CompletableFuture.allOf(futures).join();
for (int i = 0; i < numStreams; i++) {
assertEquals(
i % 2 == 0 ? PREFIX_BACKEND_IDENTITY : DEFAULT_BACKEND_IDENTITY,
futures[i].resultNow());
}
}
@Test
void backendDownStreamReset() {
// Kill the default backend so connection attempts from the omnibus fail
defaultBackend.close().syncUninterruptibly();
final Channel h2Connection = connectToOmnibus();
final CompletableFuture<Http2HeadersFrame> headersFuture = new CompletableFuture<>();
final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection)
.handler(new HeadersCollectorHandler(headersFuture))
.open()
.syncUninterruptibly()
.getNow();
final Http2Headers headers = new DefaultHttp2Headers()
.method("POST")
.path("/test")
.scheme("https")
.authority("localhost");
stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, true));
final Http2HeadersFrame responseHeaders = headersFuture.join();
assertEquals("502", responseHeaders.headers().status().toString());
// Stream is dead, but connection should stay alive
assertFalse(
h2Connection.closeFuture().awaitUninterruptibly(5, TimeUnit.MILLISECONDS),
"connection should stay open");
assertTrue(h2Connection.isOpen());
h2Connection.close().syncUninterruptibly();
}
@ParameterizedTest
@ValueSource(strings = {"/goaway", "/reset"})
void backendCloseClosesClientStream(final String path) {
final Channel h2Connection = connectToOmnibus();
final CompletableFuture<Http2ResetFrame> resetFuture = new CompletableFuture<>();
final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection)
.handler(new RstCollectorHandler(resetFuture))
.open()
.syncUninterruptibly()
.getNow();
final Http2Headers headers = new DefaultHttp2Headers()
.method("POST")
// Triggers the server stream handler to either GOAWAY+close or send an RST based on the path
.path(path)
.scheme("https")
.authority("localhost");
stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, true));
assertEquals(Http2Error.CANCEL.code(), resetFuture.join().errorCode());
// client<->omnibus h2 connection stays open after a backend close/rst
assertTrue(h2Connection.isActive());
}
@Test
void queuedDataFrames() throws Exception {
final Channel h2Connection = connectToOmnibus();
final CompletableFuture<String> responseFuture = new CompletableFuture<>();
final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection)
.handler(new ResponseCollectorHandler(responseFuture))
.open()
.syncUninterruptibly()
.getNow();
final Http2Headers headers = new DefaultHttp2Headers()
.method("POST")
.path("/test")
.scheme("https")
.authority("localhost");
final int numFrames = 64;
// Omnibus should handle queueing up frames while connecting to the backend if we blast all the frames right away
stream.write(new DefaultHttp2HeadersFrame(headers, false));
final StringBuilder expectedBuilder = new StringBuilder();
for (int i = 0; i < numFrames; i++) {
final String chunk = String.format("chunk-%03d;", i);
expectedBuilder.append(chunk);
final boolean endStream = (i == numFrames - 1);
stream.write(new DefaultHttp2DataFrame(
Unpooled.copiedBuffer(chunk, StandardCharsets.UTF_8), endStream));
}
stream.flush();
final String expected = expectedBuilder.toString();
final String response = responseFuture.get(10, TimeUnit.SECONDS);
assertEquals(expected, response);
h2Connection.close().syncUninterruptibly();
}
@Test
void clientDisconnectClosesBackendConnections() throws Exception {
final Channel h2Connection = connectToOmnibus();
final Http2StreamChannelBootstrap streamBootstrap = new Http2StreamChannelBootstrap(h2Connection);
final Http2StreamChannel stream = streamBootstrap
.handler(new ResponseCollectorHandler(new CompletableFuture<>()))
.open()
.syncUninterruptibly()
.getNow();
// Write an endStream=false header so the stream stays open
final Http2Headers headers = new DefaultHttp2Headers()
.method("POST")
.path("/test")
.scheme("https")
.authority("localhost");
stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, false)).syncUninterruptibly();
final Channel backendServerChannel = backendConnection.join();
assertFalse(
backendServerChannel.closeFuture().awaitUninterruptibly(10, TimeUnit.MILLISECONDS),
"Channel should be open");
// All backend connections the omnibus opened on behalf of this client should close if we disconnect the client
h2Connection.close().syncUninterruptibly();
assertTrue(backendServerChannel.closeFuture().await(5, TimeUnit.SECONDS));
}
@Test
void backpressure() throws ExecutionException, InterruptedException, TimeoutException {
final Channel h2Connection = connectToOmnibus();
// We'll take the client channel becoming unwritable as backpressure signal
final AtomicBoolean isWritable = new AtomicBoolean(true);
final CompletableFuture<Http2HeadersFrame> response = new CompletableFuture<>();
final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection)
.handler(new ChannelInboundHandlerAdapter() {
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) {
isWritable.set(ctx.channel().isWritable());
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
if (msg instanceof Http2HeadersFrame headers) {
response.complete(headers);
}
} finally {
ReferenceCountUtil.release(msg);
}
}
})
.open()
.syncUninterruptibly()
.getNow();
final Http2Headers headers = new DefaultHttp2Headers().method("POST").path("/test");
stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, false)).syncUninterruptibly();
// Make the backend H2 stream processor 'slow' by disabling auto-read on it
final Channel backendServerChannel = backendConnection.join();
backendServerChannel.config().setAutoRead(false);
final byte[] chunk = new byte[16384];
do {
// Write data until our own client hits the high watermark
while (isWritable.get()) {
// Try to wait until the write finishes but if it can't that's fine: we're trying to induce backpressure
stream
.writeAndFlush(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(chunk), false))
.awaitUninterruptibly(100, TimeUnit.MILLISECONDS);
Thread.yield();
}
// Make sure our channel is still unwritable for a bit since we haven't re-enabled auto-read yet. If we become
// writable it means we are hitting a lower watermark somewhere earlier in the stack, so we can try writing some
// more. Eventually all intermediate channels should flush and we should be stuck on the backend channel which
// will never make progress (because auto-read is disabled)
Thread.sleep(100);
} while (isWritable.get());
stream.writeAndFlush(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(chunk), true));
// Now re-enable reads on the backend, which should eventually unblock our writes
backendConnection.resultNow().config().setAutoRead(true);
// Now we should eventually be able to send the last (endStream=true) write and get a response
assertEquals("200", response.get(5, TimeUnit.SECONDS).headers().status().toString());
assertTrue(isWritable.get());
h2Connection.close().syncUninterruptibly();
}
@Test
void idleTest() throws Exception {
final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12");
final Duration timeout = Duration.ofMillis(500);
final OmnibusH2Server timeoutServer = new OmnibusH2Server(
SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD),
nioEventLoopGroup,
localEventLoopGroup,
new InetSocketAddress("127.0.0.1", 0),
new OmnibusRouter(
List.of(new OmnibusRouter.OmnibusRoute(PREFIX, prefixBackend.localAddress())),
defaultBackend.localAddress()),
timeout);
timeoutServer.start();
final Channel channel = connectToOmnibus(timeoutServer, null);
// Send a request to make sure idleTimeouts work even after a stream / backend connection has been established
sendRequestThroughOmnibus(channel, "/a/different/path");
// The server should eventually close this idle connection
assertTrue(channel.closeFuture().awaitUninterruptibly(timeout.toMillis() * 5, TimeUnit.MILLISECONDS));
timeoutServer.stop();
}
private Channel startH2CServer(final boolean local, final String identity) throws InterruptedException {
final EventLoopGroup eventLoopGroup = local ? localEventLoopGroup : nioEventLoopGroup;
return new ServerBootstrap()
.group(eventLoopGroup, eventLoopGroup)
.channel(local ? LocalServerChannel.class : NioServerSocketChannel.class)
// Limit size of kernel TCP buffers to make it easier to hit backpressure in tests
.option(NioChannelOption.SO_RCVBUF, 8192)
.option(NioChannelOption.SO_SNDBUF, 8192)
.childHandler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel ch) {
backendConnection.complete(ch);
ch.pipeline().addLast(Http2FrameCodecBuilder.forServer().build());
ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
@Override
protected void initChannel(final Http2StreamChannel ch) {
ch.pipeline().addLast(new TestHandler(identity));
}
}));
}
})
.bind(local ? new LocalAddress(identity) : new InetSocketAddress("127.0.0.1", 0))
.sync()
.channel();
}
private Channel connectToOmnibus() {
return connectToOmnibus(null, null);
}
/// Makes an H2 connection to the omnibus at [this#server] on which new H2 streams can be opened
private Channel connectToOmnibus(@Nullable OmnibusH2Server server, @Nullable final HAProxyMessage proxyHeader) {
if (server == null) {
server = this.server;
}
final SslContext clientSsl;
try {
clientSsl = SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.applicationProtocolConfig(new ApplicationProtocolConfig(
ApplicationProtocolConfig.Protocol.ALPN,
ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
ApplicationProtocolNames.HTTP_2))
.build();
} catch (SSLException e) {
throw new RuntimeException(e);
}
final Bootstrap clientBootstrap = new Bootstrap()
.group(nioEventLoopGroup)
.channel(NioSocketChannel.class)
// Limit size of kernel TCP buffers to make it easier to hit backpressure in tests
.option(NioChannelOption.SO_RCVBUF, 8192)
.option(NioChannelOption.SO_SNDBUF, 8192)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel ch) {
ch.pipeline().addLast(HAProxyMessageEncoder.INSTANCE);
}
});
final Channel ch = clientBootstrap.connect(server.getLocalAddress())
.syncUninterruptibly()
.channel();
if (proxyHeader != null) {
ch.writeAndFlush(proxyHeader).syncUninterruptibly();
}
ch.pipeline().remove(HAProxyMessageEncoder.INSTANCE);
ch.pipeline().addLast(clientSsl.newHandler(ch.alloc(), server.getLocalAddress().getHostName(), server.getLocalAddress().getPort()));
ch.pipeline().addLast(Http2FrameCodecBuilder.forClient()
.initialSettings(Http2Settings.defaultSettings())
.build());
ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInboundHandlerAdapter()));
return ch;
}
/// Sends an H2 request to [this#server] and returns the H2 response body
private String sendRequestThroughOmnibus(final String path) {
final Channel h2Connection = connectToOmnibus();
final String result = sendRequestThroughOmnibus(h2Connection, path);
h2Connection.close().syncUninterruptibly();
return result;
}
private String sendRequestThroughOmnibus(final Channel h2Connection, final String path) {
final CompletableFuture<String> responseFuture = new CompletableFuture<>();
final Http2StreamChannelBootstrap streamBootstrap = new Http2StreamChannelBootstrap(h2Connection);
final Http2StreamChannel stream = streamBootstrap
.handler(new ResponseCollectorHandler(responseFuture))
.open()
.syncUninterruptibly()
.getNow();
final Http2Headers headers = new DefaultHttp2Headers()
.method("POST")
.path(path)
.scheme("https")
.authority("localhost");
stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, true));
return responseFuture.join();
}
/// A backend that either echos the request body, returns an identity, or disconnects based on the request
private static class TestHandler extends ChannelInboundHandlerAdapter {
// Returned if request has no body
private final String identity;
private final ByteBuf accumulated = Unpooled.buffer();
private TestHandler(final String identity) {
this.identity = identity;
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (msg instanceof Http2HeadersFrame headers) {
final String path = headers.headers().path().toString();
if (path.contains("reset")) {
ctx.writeAndFlush(new DefaultHttp2ResetFrame(Http2Error.NO_ERROR))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} else if (path.contains("goaway")) {
ctx.channel().parent()
.writeAndFlush(new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR))
.addListener(ChannelFutureListener.CLOSE);
} else if (path.contains("forwarded-for")) {
final String xForwardedFor = Optional
.ofNullable(headers.headers().get("x-forwarded-for"))
.map(CharSequence::toString)
.orElse("");
writeResponse(ctx, Unpooled.copiedBuffer(xForwardedFor, StandardCharsets.UTF_8));
} else if (headers.isEndStream()) {
writeResponse(ctx, Unpooled.copiedBuffer(identity, StandardCharsets.UTF_8));
}
} else if (msg instanceof Http2DataFrame dataFrame) {
accumulated.writeBytes(dataFrame.content());
if (dataFrame.isEndStream()) {
writeResponse(ctx, accumulated);
}
}
ReferenceCountUtil.release(msg);
}
private void writeResponse(final ChannelHandlerContext ctx, final ByteBuf body) {
final Http2Headers responseHeaders = new DefaultHttp2Headers().status("200");
ctx.write(new DefaultHttp2HeadersFrame(responseHeaders, false));
ctx.writeAndFlush(new DefaultHttp2DataFrame(body, true));
}
}
/// Completes the provided future with the first [Http2DataFrame] received
private static class ResponseCollectorHandler extends ChannelInboundHandlerAdapter {
private final CompletableFuture<String> responseFuture;
ResponseCollectorHandler(final CompletableFuture<String> responseFuture) {
this.responseFuture = responseFuture;
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (msg instanceof Http2DataFrame dataFrame) {
responseFuture.complete(dataFrame.content().toString(StandardCharsets.UTF_8));
}
ReferenceCountUtil.release(msg);
}
@Override
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
responseFuture.completeExceptionally(cause);
}
}
/// Completes the provided future with the first [Http2HeadersFrame] received
private static class HeadersCollectorHandler extends ChannelInboundHandlerAdapter {
private final CompletableFuture<Http2HeadersFrame> responseFuture;
HeadersCollectorHandler(final CompletableFuture<Http2HeadersFrame> responseFuture) {
this.responseFuture = responseFuture;
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (msg instanceof Http2HeadersFrame headers) {
responseFuture.complete(headers);
}
ReferenceCountUtil.release(msg);
}
@Override
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
responseFuture.completeExceptionally(cause);
}
}
/// Completes the provided future when an RST frame is received or errors if we don't get one
private static class RstCollectorHandler extends SimpleUserEventChannelHandler<Http2ResetFrame> {
private final CompletableFuture<Http2ResetFrame> resetFuture;
private RstCollectorHandler(final CompletableFuture<Http2ResetFrame> resetFuture) {
this.resetFuture = resetFuture;
}
@Override
protected void eventReceived(final ChannelHandlerContext ctx, final Http2ResetFrame evt) {
resetFuture.complete(evt);
}
@Override
public void channelInactive(final ChannelHandlerContext ctx) {
if (!resetFuture.isDone()) {
resetFuture.completeExceptionally(new IllegalStateException("Channel went inactive without RST"));
}
}
}
}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import org.junit.jupiter.api.Test;
class ProxyProtocolHandlerTest extends AbstractLeakDetectionTest {
private static final HAProxyMessage PROXY_MESSAGE = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY,
HAProxyProxiedProtocol.TCP4, "10.0.0.1", "10.0.0.2", 1234, 5678);
@Test
void sendHeader() {
final EmbeddedChannel encoder = new EmbeddedChannel(HAProxyMessageEncoder.INSTANCE);
encoder.writeOutbound(PROXY_MESSAGE.retain());
final ByteBuf ppv2Bytes = encoder.readOutbound();
final EmbeddedChannel ch = new EmbeddedChannel(new ProxyProtocolHandler());
ch.writeInbound(ppv2Bytes);
final HAProxyMessage actual = ch.readInbound();
assertEquals(PROXY_MESSAGE.protocolVersion(), actual.protocolVersion());
assertEquals(PROXY_MESSAGE.sourceAddress(), actual.sourceAddress());
}
@Test
void sendHeaderSlowly() {
final EmbeddedChannel encoder = new EmbeddedChannel(HAProxyMessageEncoder.INSTANCE);
encoder.writeOutbound(PROXY_MESSAGE.retain());
final ByteBuf ppv2Bytes = encoder.readOutbound();
final EmbeddedChannel ch = new EmbeddedChannel(new ProxyProtocolHandler());
while (ppv2Bytes.isReadable()) {
assertNull(ch.readInbound());
ch.writeInbound(ppv2Bytes.readBytes(1));
}
final HAProxyMessage actual = ch.readInbound();
assertEquals(PROXY_MESSAGE.protocolVersion(), actual.protocolVersion());
assertEquals(PROXY_MESSAGE.sourceAddress(), actual.sourceAddress());
}
}

View File

@ -0,0 +1,179 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.SimpleUserEventChannelHandler;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.Mapping;
import java.io.InputStream;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
class SniMapperTest {
// Configuration for precomputed keystore blob defined in sni-mapper-test-keystore.p12
private static final String FOO_DOMAIN = "foo.example.com";
private static final String BAR_DOMAIN = "bar.example.com";
private static final String KEY_STORE_PASSWORD = "password";
private static final String KEY_STORE_NAME = "sni-mapper-test-keystore.p12";
private DefaultEventLoopGroup eventLoopGroup;
private Channel serverChannel;
@BeforeEach
void setUp() throws Exception {
final InputStream keyStore = SniMapper.class.getResourceAsStream(KEY_STORE_NAME);
eventLoopGroup = new DefaultEventLoopGroup();
final Mapping<String, SslContext> sniMapping =
SniMapper.buildSniMapping(keyStore, KEY_STORE_PASSWORD);
final LocalAddress localAddress = new LocalAddress(SniMapper.class.getSimpleName());
serverChannel = new ServerBootstrap()
.group(eventLoopGroup)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel ch) {
ch.pipeline().addLast(new SniHandler(sniMapping));
}
})
.bind(localAddress)
.sync()
.channel();
}
@AfterEach
void tearDown() throws Exception {
if (serverChannel != null) {
serverChannel.close().sync();
}
eventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync();
}
@Test
void unknownDomain() throws Exception {
final InputStream keyStore = SniMapper.class.getResourceAsStream(KEY_STORE_NAME);
final Mapping<String, SslContext> sniMapping = SniMapper.buildSniMapping(keyStore, KEY_STORE_PASSWORD);
assertNotNull(sniMapping.map("unknown.example.com"));
final X509Certificate defaultCertificate = connectAndGetServerCertificate("unknown.example.com", null);
// bar.example.com is the lexicographically first domain, so we should default to it.
assertCertificateIsForDomain(defaultCertificate, BAR_DOMAIN);
}
static List<Arguments> selectCertificate() {
return List.of(
Arguments.of(FOO_DOMAIN, List.of(), "Ed25519"),
Arguments.of(BAR_DOMAIN, List.of(), "Ed25519"),
Arguments.of(BAR_DOMAIN, List.of("ed25519"), "Ed25519"),
Arguments.of(FOO_DOMAIN, List.of("rsa_pss_rsae_sha256", "rsa_pss_rsae_sha384", "rsa_pss_rsae_sha512"), "SHA256withRSA"),
Arguments.of(FOO_DOMAIN, List.of("rsa_pss_rsae_sha256", "rsa_pss_rsae_sha384", "rsa_pss_rsae_sha512", "ed25519"), "SHA256withRSA"),
Arguments.of(FOO_DOMAIN, List.of("ed25519", "rsa_pss_rsae_sha256", "rsa_pss_rsae_sha384", "rsa_pss_rsae_sha512"), "Ed25519")
);
}
@ParameterizedTest
@MethodSource
void selectCertificate(final String sni, final List<String> signatureSchemes, final String expectedSigAlgorithm)
throws Exception {
final X509Certificate serverCert = connectAndGetServerCertificate(sni, signatureSchemes.toArray(String[]::new));
assertNotNull(serverCert);
assertCertificateIsForDomain(serverCert, sni);
assertEquals(expectedSigAlgorithm, serverCert.getSigAlgName());
}
private X509Certificate connectAndGetServerCertificate(final String sniHostname,
final String[] signatureSchemes) throws Exception {
final SslContext clientSsl = SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.protocols("TLSv1.3")
.build();
final CompletableFuture<X509Certificate> certFuture = new CompletableFuture<>();
final Bootstrap clientBootstrap = new Bootstrap()
.group(eventLoopGroup)
.channel(LocalChannel.class)
.handler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(final LocalChannel ch) {
final SSLEngine engine = clientSsl.newEngine(ch.alloc());
final SSLParameters params = engine.getSSLParameters();
params.setServerNames(List.of(new SNIHostName(sniHostname)));
if (signatureSchemes != null && signatureSchemes.length != 0) {
params.setSignatureSchemes(signatureSchemes);
}
engine.setSSLParameters(params);
final SslHandler sslHandler = new SslHandler(engine);
ch.pipeline().addLast(sslHandler);
ch.pipeline().addLast(new SimpleUserEventChannelHandler<SslHandshakeCompletionEvent>() {
@Override
protected void eventReceived(final ChannelHandlerContext ctx, final SslHandshakeCompletionEvent evt) {
if (!evt.isSuccess()) {
certFuture.completeExceptionally(evt.cause());
return;
}
try {
final SSLSession session = sslHandler.engine().getSession();
final X509Certificate cert = (X509Certificate) session.getPeerCertificates()[0];
certFuture.complete(cert);
} catch (final SSLPeerUnverifiedException e) {
certFuture.completeExceptionally(e);
}
}
});
}
});
final Channel clientChannel = clientBootstrap.connect(serverChannel.localAddress()).sync().channel();
try {
return certFuture.get(5, TimeUnit.SECONDS);
} finally {
clientChannel.close().sync();
}
}
private static void assertCertificateIsForDomain(final X509Certificate cert, final String expectedDomain)
throws Exception {
assertTrue(cert.getSubjectAlternativeNames().stream()
.filter(san -> (int) san.getFirst() == 2) // dNSName
.map(san -> (String) san.get(1))
.anyMatch(name -> name.equalsIgnoreCase(expectedDomain)));
}
}

View File

@ -28,7 +28,6 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import jakarta.annotation.Priority;
import jakarta.servlet.DispatcherType;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.InternalServerErrorException;
import jakarta.ws.rs.NotAuthorizedException;
@ -45,7 +44,6 @@ import java.io.IOException;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
@ -55,25 +53,28 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.security.auth.Subject;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.component.Container;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.handler.EventsHandler;
import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@ -81,7 +82,8 @@ import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
class MetricsHttpChannelListenerIntegrationTest {
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class MetricsHttpEventHandlerIntegrationTest {
private static final TrafficSource TRAFFIC_SOURCE = TrafficSource.HTTP;
private static final MeterRegistry METER_REGISTRY = mock(MeterRegistry.class);
@ -91,7 +93,7 @@ class MetricsHttpChannelListenerIntegrationTest {
private static final AtomicReference<CountDownLatch> COUNT_DOWN_LATCH_FUTURE_REFERENCE = new AtomicReference<>();
private static final DropwizardAppExtension<Configuration> EXTENSION = new DropwizardAppExtension<>(
MetricsHttpChannelListenerIntegrationTest.TestApplication.class);
MetricsHttpEventHandlerIntegrationTest.TestApplication.class);
@AfterEach
void teardown() {
@ -112,9 +114,9 @@ class MetricsHttpChannelListenerIntegrationTest {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
final Map<String, Counter> counterMap = Map.of(
MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
);
when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class)));
@ -148,7 +150,7 @@ class MetricsHttpChannelListenerIntegrationTest {
assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS));
verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(REQUEST_COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue();
@ -159,10 +161,10 @@ class MetricsHttpChannelListenerIntegrationTest {
}
assertEquals(Set.of(
Tag.of(MetricsHttpChannelListener.PATH_TAG, expectedTagPath),
Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET"),
Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(expectedStatus)),
Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()),
Tag.of(MetricsHttpEventHandler.PATH_TAG, expectedTagPath),
Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET"),
Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(expectedStatus)),
Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()),
Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")),
tags);
}
@ -186,7 +188,7 @@ class MetricsHttpChannelListenerIntegrationTest {
@Test
void testWebSocketUpgrade() throws Exception {
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest(URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")));
upgradeRequest.setHeader(HttpHeaders.USER_AGENT, "Signal-Android/4.53.7 (Android 8.1)");
final CountDownLatch countDownLatch = new CountDownLatch(1);
@ -194,24 +196,18 @@ class MetricsHttpChannelListenerIntegrationTest {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
final Map<String, Counter> counterMap = Map.of(
MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
);
when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class)));
client.connect(new WebSocketListener() {
@Override
public void onWebSocketConnect(final Session session) {
session.close(1000, "OK");
}
},
URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest);
client.connect(new AutoClosingWebSocketSessionListener(), upgradeRequest);
assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS));
verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(REQUEST_COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue();
@ -222,10 +218,10 @@ class MetricsHttpChannelListenerIntegrationTest {
}
assertEquals(Set.of(
Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket"),
Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET"),
Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101)),
Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()),
Tag.of(MetricsHttpEventHandler.PATH_TAG, "/v1/websocket"),
Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET"),
Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(101)),
Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()),
Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")),
tags);
}
@ -248,9 +244,9 @@ class MetricsHttpChannelListenerIntegrationTest {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
final Map<String, Counter> counterMap = Map.of(
MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
);
when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class)));
@ -263,7 +259,7 @@ class MetricsHttpChannelListenerIntegrationTest {
assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS));
verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(REQUEST_COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue();
@ -273,7 +269,7 @@ class MetricsHttpChannelListenerIntegrationTest {
tags.add(tag);
}
assertFalse(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "ANNOY")));
assertFalse(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, "ANNOY")));
}
}
@ -283,17 +279,16 @@ class MetricsHttpChannelListenerIntegrationTest {
public void run(final Configuration configuration,
final Environment environment) throws Exception {
final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(
METER_REGISTRY,
mock(ClientReleaseManager.class),
Set.of("/v1/websocket")
);
MetricsHttpEventHandler.configure(environment, METER_REGISTRY, mock(ClientReleaseManager.class), Set.of("/v1/websocket"));
metricsHttpChannelListener.configure(environment);
environment.lifecycle().addEventListener(new TestListener(COUNT_DOWN_LATCH_FUTURE_REFERENCE));
environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
environment.lifecycle().addEventListener(new LifeCycle.Listener() {
@Override
public void lifeCycleStarting(final LifeCycle event) {
if (event instanceof Server server) {
server.setHandler(new TestListener(server.getHandler(), COUNT_DOWN_LATCH_FUTURE_REFERENCE));
}
}
});
environment.jersey().register(new TestResource());
environment.jersey().register(new TestAuthFilter());
@ -306,14 +301,15 @@ class MetricsHttpChannelListenerIntegrationTest {
webSocketEnvironment.jersey().register(new TestResource());
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
webSocketEnvironment, TestPrincipal.class,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
environment.servlets().addServlet("WebSocket", webSocketServlet)
.addMapping("/v1/websocket");
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(),
(servletContext, container) -> {
container.addMapping("/v1/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
}
}
@ -329,36 +325,23 @@ class MetricsHttpChannelListenerIntegrationTest {
}
/**
* A simple listener to signal that {@link HttpChannel.Listener} has completed its work, since its onComplete() is on
* A simple listener to signal that {@link EventsHandler} has completed its work, since its onComplete() is on
* a different thread from the one that sends the response, creating a race condition between the listener and the
* test assertions
*/
static class TestListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener {
static class TestListener extends EventsHandler {
private final AtomicReference<CountDownLatch> completableFutureAtomicReference;
TestListener(AtomicReference<CountDownLatch> countDownLatchReference) {
TestListener(final Handler handler, AtomicReference<CountDownLatch> countDownLatchReference) {
super(handler);
this.completableFutureAtomicReference = countDownLatchReference;
}
@Override
public void onComplete(final Request request) {
public void onComplete(Request request, int status, HttpFields headers, Throwable failure) {
completableFutureAtomicReference.get().countDown();
}
@Override
public void beanAdded(final Container parent, final Object child) {
if (child instanceof Connector connector) {
connector.addBean(this);
}
}
@Override
public void beanRemoved(final Container parent, final Object child) {
}
}
@Path("/v1/test")
@ -400,4 +383,11 @@ class MetricsHttpChannelListenerIntegrationTest {
}
}
public static class AutoClosingWebSocketSessionListener implements Session.Listener.AutoDemanding {
@Override
public void onWebSocketOpen(final Session session) {
session.close(1000, "OK", Callback.NOOP);
}
}
}

View File

@ -20,24 +20,33 @@ import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import java.util.Collections;
import java.nio.ByteBuffer;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.io.Content;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.glassfish.jersey.uri.UriTemplate;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import javax.annotation.Nullable;
class MetricsHttpChannelListenerTest {
class MetricsHttpEventHandlerTest {
private final static String USER_AGENT = "Signal-Android/6.53.7 (Android 8.1)";
private MeterRegistry meterRegistry;
private Counter requestCounter;
@ -45,7 +54,7 @@ class MetricsHttpChannelListenerTest {
private Counter responseBytesCounter;
private Counter requestBytesCounter;
private ClientReleaseManager clientReleaseManager;
private MetricsHttpChannelListener listener;
private MetricsHttpEventHandler listener;
@BeforeEach
void setup() {
@ -55,27 +64,28 @@ class MetricsHttpChannelListenerTest {
responseBytesCounter = mock(Counter.class);
requestBytesCounter = mock(Counter.class);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestCounter);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Tags.class)))
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME), any(Tags.class)))
.thenReturn(requestsByVersionCounter);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class)))
when(meterRegistry.counter(eq(MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(responseBytesCounter);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter);
clientReleaseManager = mock(ClientReleaseManager.class);
listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet());
listener = new MetricsHttpEventHandler(null, meterRegistry, clientReleaseManager, Set.of("/test"));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
@CartesianTest
@SuppressWarnings("unchecked")
void testRequests(final boolean versionActive) {
void testRequests(@CartesianTest.Values(booleans = {true, false}) final boolean pathFromFilter,
@CartesianTest.Values(booleans = {true, false}) final boolean versionActive) {
final String path = "/test";
final String method = "GET";
final int statusCode = 200;
@ -85,30 +95,40 @@ class MetricsHttpChannelListenerTest {
final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)");
final HttpFields.Mutable requestHeaders = HttpFields.build();
requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT);
when(request.getHeaders()).thenReturn(requestHeaders);
when(request.getHttpURI()).thenReturn(httpUri);
if (pathFromFilter) {
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT));
} else {
when(request.setAttribute(eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME), any())).thenAnswer(invocation -> {
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(invocation.getArgument(1));
return null;
});
}
final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode);
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive);
when(response.getContentCount()).thenReturn(1024L);
when(request.getResponse()).thenReturn(response);
when(request.getContentRead()).thenReturn(512L);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive);
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
listener.onComplete(request);
listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true));
listener.onResponseWrite(request, true, ByteBuffer.allocate(1024));
listener.onComplete(request, statusCode, requestHeaders, null);
verify(requestCounter).increment();
verify(responseBytesCounter).increment(1024L);
verify(requestBytesCounter).increment(512L);
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture());
final Set<Tag> tags = new HashSet<>();
for (final Tag tag : tagCaptor.getValue()) {
@ -116,10 +136,10 @@ class MetricsHttpChannelListenerTest {
}
final Set<Tag> expectedTags = new HashSet<>(Set.of(
Tag.of(MetricsHttpChannelListener.PATH_TAG, path),
Tag.of(MetricsHttpChannelListener.METHOD_TAG, method),
Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(statusCode)),
Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()),
Tag.of(MetricsHttpEventHandler.PATH_TAG, path),
Tag.of(MetricsHttpEventHandler.METHOD_TAG, method),
Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(statusCode)),
Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()),
Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
if (versionActive) {
@ -143,23 +163,22 @@ class MetricsHttpChannelListenerTest {
final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)");
final HttpFields.Mutable requestHeaders = HttpFields.build();
requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT);
when(request.getHeaders()).thenReturn(requestHeaders);
when(request.getHttpURI()).thenReturn(httpUri);
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT));
final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode);
when(response.getContentCount()).thenReturn(1024L);
when(request.getResponse()).thenReturn(response);
when(request.getContentRead()).thenReturn(512L);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
listener.onComplete(request);
listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true));
listener.onResponseWrite(request, true, ByteBuffer.allocate(1024));
listener.onComplete(request, statusCode, requestHeaders, null);
if (versionActive) {
final ArgumentCaptor<Tags> tagCaptor = ArgumentCaptor.forClass(Tags.class);
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME),
verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME),
tagCaptor.capture());
final Set<Tag> tags = new HashSet<>();
tags.clear();
@ -178,7 +197,7 @@ class MetricsHttpChannelListenerTest {
@ParameterizedTest
@MethodSource
void normalizeMethod(@Nullable final String originalMethod, final String expectedMethod) {
assertEquals(expectedMethod, MetricsHttpChannelListener.normalizeMethod(originalMethod));
assertEquals(expectedMethod, MetricsHttpEventHandler.normalizeMethod(originalMethod));
}
private static List<Arguments> normalizeMethod() {
@ -190,4 +209,37 @@ class MetricsHttpChannelListenerTest {
Arguments.arguments("get", "get")
);
}
@Test
void testResponseFilterSetsRequestInfo() {
final ContainerRequest request = mock(ContainerRequest.class);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test")));
when(request.getMethod()).thenReturn("GET");
when(request.getHeaders()).thenReturn(null);
when(request.getUriInfo()).thenReturn(extendedUriInfo);
when(request.getHeaderString(HttpHeaders.USER_AGENT)).thenReturn(USER_AGENT);
new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class));
verify(request).setProperty(
eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME),
eq(new MetricsHttpEventHandler.RequestInfo("/test", "GET", USER_AGENT)));
}
@Test
void testResponseFilterModifiesRequestInfo() {
final MetricsHttpEventHandler.RequestInfo requestInfo =
new MetricsHttpEventHandler.RequestInfo("unknown", "POST", USER_AGENT);
final ContainerRequest request = mock(ContainerRequest.class);
when(request.getProperty(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)).thenReturn(requestInfo);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test")));
when(request.getUriInfo()).thenReturn(extendedUriInfo);
new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class));
assertEquals(new MetricsHttpEventHandler.RequestInfo("/test", "POST", USER_AGENT), requestInfo);
}
}

View File

@ -36,10 +36,9 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
@ -175,11 +174,9 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)");
when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of("Signal-Android/4.53.7 (Android 8.1)")));
@ -191,15 +188,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -244,11 +241,9 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn(
@ -258,15 +253,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -314,11 +309,9 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn(
@ -328,15 +321,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);

View File

@ -149,15 +149,13 @@ class TlsCertificateExpirationUtilTest {
@Test
void test() throws Exception {
try (Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64)) {
final Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64);
final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD);
final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD);
final Map<String, Instant> expected = Map.of(
"localhost:EdDSA", EDDSA_EXPIRATION,
"localhost:RSA", RSA_EXPIRATION);
assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD));
}
final Map<String, Instant> expected = Map.of(
"localhost:EdDSA", EDDSA_EXPIRATION,
"localhost:RSA", RSA_EXPIRATION);
assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD));
}
}

View File

@ -419,13 +419,6 @@ class FaultTolerantRedisClusterClientTest {
return this;
}
@Override
@Deprecated
public ClientResources.Builder dnsResolver(final DnsResolver dnsResolver) {
delegate.dnsResolver(dnsResolver);
return this;
}
@Override
public ClientResources.Builder eventBus(final EventBus eventBus) {
delegate.eventBus(eventBus);

View File

@ -4,14 +4,6 @@
*/
package org.whispersystems.textsecuregcm.tests.util;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Objects;
@ -19,8 +11,14 @@ import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
public class TestWebsocketListener implements WebSocketListener {
public class TestWebsocketListener implements Session.Listener.AutoDemanding {
private final AtomicLong requestId = new AtomicLong();
private final CompletableFuture<Session> started = new CompletableFuture<>();
@ -34,7 +32,7 @@ public class TestWebsocketListener implements WebSocketListener {
@Override
public void onWebSocketConnect(final Session session) {
public void onWebSocketOpen(final Session session) {
started.complete(session);
}
@ -63,19 +61,15 @@ public class TestWebsocketListener implements WebSocketListener {
responseFutures.put(id, future);
final byte[] requestBytes = messageFactory.createRequest(
Optional.of(id), verb, requestPath, headers, body).toByteArray();
try {
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP);
return future;
});
}
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
.complete(webSocketMessage.getResponseMessage());

View File

@ -6,12 +6,9 @@
package org.whispersystems.textsecuregcm.util.jetty;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URI;
import java.nio.channels.ReadableByteChannel;
import java.nio.file.Path;
import java.util.Base64;
import org.eclipse.jetty.util.resource.Resource;
@ -30,15 +27,16 @@ public class TestResource extends Resource {
}
@Override
public boolean isContainedIn(final Resource r) throws MalformedURLException {
return false;
public Path getPath() {
return null;
}
@Override
public void close() {
public InputStream newInputStream() {
return new ByteArrayInputStream(data);
}
@Override
public boolean exists() {
return true;
@ -50,13 +48,13 @@ public class TestResource extends Resource {
}
@Override
public long lastModified() {
return 0;
public boolean isReadable() {
return true;
}
@Override
public long length() {
return 0;
return data.length;
}
@Override
@ -64,43 +62,19 @@ public class TestResource extends Resource {
return null;
}
@Override
public File getFile() throws IOException {
return null;
}
@Override
public String getName() {
return name;
}
@Override
public InputStream getInputStream() throws IOException {
return new ByteArrayInputStream(data);
public String getFileName() {
return "";
}
@Override
public ReadableByteChannel getReadableByteChannel() throws IOException {
public Resource resolve(final String subUriPath) {
return null;
}
@Override
public boolean delete() throws SecurityException {
return false;
}
@Override
public boolean renameTo(final Resource dest) throws SecurityException {
return false;
}
@Override
public String[] list() {
return new String[]{name};
}
@Override
public Resource addPath(final String path) throws IOException, MalformedURLException {
return this;
}
}

View File

@ -39,10 +39,9 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.server.ServerProperties;
@ -143,12 +142,12 @@ class LoggingUnhandledExceptionMapperTest {
WebSocketResourceProvider<TestPrincipal> provider = createWebsocketProvider(userAgentHeader, session,
responseFuture::complete);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory()
.createRequest(Optional.of(111L), "GET", targetPath, new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
responseFuture.get(1, TimeUnit.SECONDS);
@ -179,15 +178,13 @@ class LoggingUnhandledExceptionMapperTest {
TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
doAnswer(answer -> {
responseHandler.accept(answer.getArgument(0, ByteBuffer.class));
return null;
}).when(remoteEndpoint).sendBytes(any(), any(WriteCallback.class));
}).when(session).sendBinary(any(ByteBuffer.class), any(Callback.class));
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn(userAgentHeader);
when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of(userAgentHeader)));

View File

@ -19,7 +19,7 @@ import java.util.Optional;
import java.util.UUID;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
@ -44,7 +44,7 @@ class WebSocketAccountAuthenticatorTest {
private AccountAuthenticator accountAuthenticator;
private UpgradeRequest upgradeRequest;
private JettyServerUpgradeRequest upgradeRequest;
@BeforeEach
void setUp() {
@ -56,7 +56,7 @@ class WebSocketAccountAuthenticatorTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty());
upgradeRequest = mock(UpgradeRequest.class);
upgradeRequest = mock(JettyServerUpgradeRequest.class);
}
@ParameterizedTest

View File

@ -337,6 +337,8 @@ dynamicConfig:
object: |
captcha:
scoreFloor: 1.0
grpcAllowList:
enableAll: true
remoteConfig:
globalConfig: # keys and values that are given to clients on GET /v1/config
@ -521,6 +523,8 @@ idlePrimaryDeviceReminder:
grpc:
port: 50051
websocketPort: 8080
h2c: true
asnTable:
s3Region: a-region
@ -536,3 +540,18 @@ callQualitySurvey:
hlrLookup:
apiKey: secret://hlrLookup.apiKey
apiSecret: secret://hlrLookup.apiSecret
server:
allowedMethods:
- GET
- POST
- PUT
- DELETE
- HEAD
- OPTIONS
- PATCH
- CONNECT
applicationConnectors:
- type: h2c
port: 8080
useForwardedHeaders: true

View File

@ -0,0 +1,19 @@
#!/bin/sh
# Generates self-signed local testing certificates for OmnibusH2ServerTest
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
WORK_DIR="$(mktemp -d)"
trap 'rm -rf "$WORK_DIR"' EXIT
PASSWORD="password"
DAYS=36500
OMNIBUS_KS="$SCRIPT_DIR/omnibus-h2-server-test-keystore.p12"
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out "$WORK_DIR/foo-rsa.key" 2>/dev/null;
openssl req -new -x509 -key "$WORK_DIR/foo-rsa.key" -out "$WORK_DIR/foo-rsa.crt" -days "$DAYS" -subj "/CN=foo.example.com" -addext "subjectAltName=DNS:foo.example.com"
openssl pkcs12 -export -in "$WORK_DIR/foo-rsa.crt" -inkey "$WORK_DIR/foo-rsa.key" -out "$WORK_DIR/foo-rsa.p12" -name foo -passout pass:$PASSWORD
keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/foo-rsa.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$OMNIBUS_KS" -deststoretype PKCS12 -deststorepass $PASSWORD
echo "Wrote keystore to $OMNIBUS_KS"

View File

@ -0,0 +1,36 @@
#!/bin/sh
# Generates self-signed local testing certificates for SniMapperTest
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
WORK_DIR="$(mktemp -d)"
trap 'rm -rf "$WORK_DIR"' EXIT
PASSWORD="password"
DAYS=36500
SNI_KS="$SCRIPT_DIR/sni-mapper-test-keystore.p12"
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out "$WORK_DIR/foo-rsa.key" 2>/dev/null;
openssl req -new -x509 -key "$WORK_DIR/foo-rsa.key" -out "$WORK_DIR/foo-rsa.crt" -days "$DAYS" -subj "/CN=foo.example.com" -addext "subjectAltName=DNS:foo.example.com"
openssl genpkey -algorithm Ed25519 -out "$WORK_DIR/foo-ed25519.key" 2>/dev/null;
openssl req -new -x509 -key "$WORK_DIR/foo-ed25519.key" -out "$WORK_DIR/foo-ed25519.crt" -days "$DAYS" -subj "/CN=foo.example.com" -addext "subjectAltName=DNS:foo.example.com"
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out "$WORK_DIR/bar-rsa.key" 2>/dev/null;
openssl req -new -x509 -key "$WORK_DIR/bar-rsa.key" -out "$WORK_DIR/bar-rsa.crt" -days "$DAYS" -subj "/CN=bar.example.com" -addext "subjectAltName=DNS:bar.example.com"
openssl genpkey -algorithm Ed25519 -out "$WORK_DIR/bar-ed25519.key" 2>/dev/null;
openssl req -new -x509 -key "$WORK_DIR/bar-ed25519.key" -out "$WORK_DIR/bar-ed25519.crt" -days "$DAYS" -subj "/CN=BAR.EXAMPLE.COM" -addext "subjectAltName=DNS:BAR.EXAMPLE.COM"
openssl pkcs12 -export -in "$WORK_DIR/foo-rsa.crt" -inkey "$WORK_DIR/foo-rsa.key" -out "$WORK_DIR/foo-rsa.p12" -name foo -passout pass:$PASSWORD
openssl pkcs12 -export -in "$WORK_DIR/foo-ed25519.crt" -inkey "$WORK_DIR/foo-ed25519.key" -out "$WORK_DIR/foo-ed25519.p12" -name foo-ed25519 -passout pass:$PASSWORD
openssl pkcs12 -export -in "$WORK_DIR/bar-rsa.crt" -inkey "$WORK_DIR/bar-rsa.key" -out "$WORK_DIR/bar-rsa.p12" -name bar -passout pass:$PASSWORD
openssl pkcs12 -export -in "$WORK_DIR/bar-ed25519.crt" -inkey "$WORK_DIR/bar-ed25519.key" -out "$WORK_DIR/bar-ed25519.p12" -name bar-ed25519 -passout pass:$PASSWORD
keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/foo-ed25519.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$SNI_KS" -deststoretype PKCS12 -deststorepass $PASSWORD
keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/foo-rsa.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$SNI_KS" -deststoretype PKCS12 -deststorepass $PASSWORD
keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/bar-ed25519.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$SNI_KS" -deststoretype PKCS12 -deststorepass $PASSWORD
keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/bar-rsa.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$SNI_KS" -deststoretype PKCS12 -deststorepass $PASSWORD
echo "Wrote 4 certificates for two SNIs to $SNI_KS"

View File

@ -13,15 +13,19 @@
<dependencies>
<dependency>
<groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>websocket-jetty-api</artifactId>
<artifactId>jetty-websocket-jetty-api</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>websocket-jetty-server</artifactId>
<artifactId>jetty-websocket-jetty-server</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>websocket-servlet</artifactId>
<groupId>org.eclipse.jetty.ee10.websocket</groupId>
<artifactId>jetty-ee10-websocket-jetty-server</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.ee10.websocket</groupId>
<artifactId>jetty-ee10-websocket-servlet</artifactId>
</dependency>
<dependency>
<groupId>io.dropwizard</groupId>

View File

@ -13,9 +13,8 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.exceptions.WebSocketException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -30,15 +29,13 @@ public class WebSocketClient {
private static final SecureRandom SECURE_RANDOM = new SecureRandom();
private final Session session;
private final RemoteEndpoint remoteEndpoint;
private final WebSocketMessageFactory messageFactory;
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper;
private final Instant created;
public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint, WebSocketMessageFactory messageFactory,
public WebSocketClient(Session session, WebSocketMessageFactory messageFactory,
Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper) {
this.session = session;
this.remoteEndpoint = remoteEndpoint;
this.messageFactory = messageFactory;
this.pendingRequestMapper = pendingRequestMapper;
this.created = Instant.now();
@ -56,9 +53,9 @@ public class WebSocketClient {
WebSocketMessage requestMessage = messageFactory.createRequest(Optional.of(requestId), verb, path, headers, body);
try {
remoteEndpoint.sendBytes(ByteBuffer.wrap(requestMessage.toByteArray()), new WriteCallback() {
session.sendBinary(ByteBuffer.wrap(requestMessage.toByteArray()), new Callback() {
@Override
public void writeFailed(Throwable x) {
public void fail(Throwable x) {
logger.debug("Write failed", x);
pendingRequestMapper.remove(requestId);
future.completeExceptionally(x);
@ -86,9 +83,9 @@ public class WebSocketClient {
}
public void close(final int code, final String message) {
session.close(code, message, new WriteCallback() {
session.close(code, message, new Callback() {
@Override
public void writeFailed(final Throwable throwable) {
public void fail(final Throwable throwable) {
try {
session.disconnect();
} catch (final Exception e) {
@ -108,6 +105,6 @@ public class WebSocketClient {
}
public SocketAddress getRemoteAddress() {
return session.getRemoteAddress();
return session.getRemoteSocketAddress();
}
}

View File

@ -16,7 +16,6 @@ import java.net.URI;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
@ -26,12 +25,8 @@ import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.exceptions.MessageTooLargeException;
import org.glassfish.jersey.internal.MapPropertiesDelegate;
import org.glassfish.jersey.server.ApplicationHandler;
@ -51,11 +46,11 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class WebSocketResourceProvider<T extends Principal> implements WebSocketListener {
public class WebSocketResourceProvider<T extends Principal> implements Session.Listener.AutoDemanding {
/**
* A static exception instance passed to outstanding requests (via {@code completeExceptionally} in
* {@link #onWebSocketClose(int, String)}
* {@link #onWebSocketClose}
*/
public static final IOException CONNECTION_CLOSED_EXCEPTION = new IOException("Connection closed!");
private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProvider.class);
@ -73,7 +68,6 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private final int localPort;
private Session session;
private RemoteEndpoint remoteEndpoint;
private WebSocketSessionContext context;
private static final Set<String> EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade");
@ -99,11 +93,10 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
}
@Override
public void onWebSocketConnect(Session session) {
public void onWebSocketOpen(Session session) {
this.session = session;
this.remoteEndpoint = session.getRemote();
this.context = new WebSocketSessionContext(
new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
new WebSocketClient(session, messageFactory, requestMap));
this.context.setAuthenticated(reusableAuth.orElse(null));
this.session.setIdleTimeout(idleTimeout);
@ -128,9 +121,19 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
}
@Override
public void onWebSocketBinary(byte[] payload, int offset, int length) {
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
onWebSocketBinary(payload);
callback.succeed();
} catch (RuntimeException e) {
callback.fail(e);
throw e;
}
}
private void onWebSocketBinary(ByteBuffer payload) {
try {
final WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
switch (webSocketMessage.getType()) {
case REQUEST_MESSAGE:
@ -150,17 +153,21 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
}
@Override
public void onWebSocketClose(int statusCode, String reason) {
if (context != null) {
context.notifyClosed(statusCode, reason);
public void onWebSocketClose(int statusCode, String reason, Callback callback) {
try {
if (context != null) {
context.notifyClosed(statusCode, reason);
for (long requestId : requestMap.keySet()) {
CompletableFuture<WebSocketResponseMessage> outstandingRequest = requestMap.remove(requestId);
for (long requestId : requestMap.keySet()) {
CompletableFuture<WebSocketResponseMessage> outstandingRequest = requestMap.remove(requestId);
if (outstandingRequest != null) {
outstandingRequest.completeExceptionally(CONNECTION_CLOSED_EXCEPTION);
if (outstandingRequest != null) {
outstandingRequest.completeExceptionally(CONNECTION_CLOSED_EXCEPTION);
}
}
}
} finally {
callback.succeed();
}
}
@ -287,7 +294,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
}
private void close(Session session, int status, String message) {
session.close(status, message);
session.close(status, message, Callback.NOOP);
}
private void sendResponse(WebSocketRequestMessage requestMessage, ContainerResponse response,
@ -306,7 +313,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
Optional.ofNullable(body))
.toByteArray();
remoteEndpoint.sendBytes(ByteBuffer.wrap(responseBytes), WriteCallback.NOOP);
session.sendBinary(ByteBuffer.wrap(responseBytes), Callback.NOOP);
}
}
@ -318,7 +325,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
getHeaderList(error.getStringHeaders()),
Optional.empty());
remoteEndpoint.sendBytes(ByteBuffer.wrap(response.toByteArray()), WriteCallback.NOOP);
session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP);
}
}

View File

@ -13,11 +13,9 @@ import java.security.Principal;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.server.JettyWebSocketCreator;
import org.eclipse.jetty.websocket.server.JettyWebSocketServlet;
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketCreator;
import org.glassfish.jersey.CommonProperties;
import org.glassfish.jersey.server.ApplicationHandler;
import org.slf4j.Logger;
@ -25,23 +23,20 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
public class WebSocketResourceProviderFactory<T extends Principal> extends JettyWebSocketServlet implements
JettyWebSocketCreator {
public class WebSocketResourceProviderFactory<T extends Principal> implements JettyWebSocketCreator {
private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class);
private final WebSocketEnvironment<T> environment;
private final ApplicationHandler jerseyApplicationHandler;
private final WebSocketConfiguration configuration;
private final String remoteAddressPropertyName;
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass,
WebSocketConfiguration configuration, String remoteAddressPropertyName) {
String remoteAddressPropertyName) {
this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
@ -55,7 +50,6 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
this.configuration = configuration;
this.remoteAddressPropertyName = remoteAddressPropertyName;
}
@ -91,6 +85,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
// Authentication may fail for non-incorrect-credential reasons (e.g. we couldn't read from the account database).
// If that happens, we don't want to incorrectly tell clients that they provided bad credentials.
logger.warn("Authentication failure", e);
try {
response.sendError(500, "Failure");
} catch (final IOException ignored) {
@ -99,13 +94,6 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
}
}
@Override
public void configure(JettyWebSocketServletFactory factory) {
factory.setCreator(this);
factory.setMaxBinaryMessageSize(configuration.getMaxBinaryMessageSize());
factory.setMaxTextMessageSize(configuration.getMaxTextMessageSize());
}
private String getRemoteAddress(JettyServerUpgradeRequest request) {
final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName);
if (StringUtils.isBlank(remoteAddress)) {

View File

@ -7,8 +7,8 @@ package org.whispersystems.websocket.auth;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
public interface AuthenticatedWebSocketUpgradeFilter<T extends Principal> {

View File

@ -6,7 +6,7 @@ package org.whispersystems.websocket.auth;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
public interface WebSocketAuthenticator<T extends Principal> {
@ -20,5 +20,5 @@ public interface WebSocketAuthenticator<T extends Principal> {
*
* @throws InvalidCredentialsException if credentials were provided, but could not be authenticated
*/
Optional<T> authenticate(UpgradeRequest request) throws InvalidCredentialsException;
Optional<T> authenticate(JettyServerUpgradeRequest request) throws InvalidCredentialsException;
}

View File

@ -5,22 +5,23 @@
package org.whispersystems.websocket.messages;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public interface WebSocketMessageFactory {
public WebSocketMessage parseMessage(byte[] serialized, int offset, int len)
WebSocketMessage parseMessage(ByteBuffer serialized)
throws InvalidMessageException;
public WebSocketMessage createRequest(Optional<Long> requestId,
String verb, String path,
List<String> headers,
Optional<byte[]> body);
WebSocketMessage createRequest(Optional<Long> requestId,
String verb, String path,
List<String> headers,
Optional<byte[]> body);
public WebSocketMessage createResponse(long requestId, int status, String message,
List<String> headers,
Optional<byte[]> body);
WebSocketMessage createResponse(long requestId, int status, String message,
List<String> headers,
Optional<byte[]> body);
}

View File

@ -10,14 +10,15 @@ import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketRequestMessage;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import java.nio.ByteBuffer;
public class ProtobufWebSocketMessage implements WebSocketMessage {
private final SubProtocol.WebSocketMessage message;
ProtobufWebSocketMessage(byte[] buffer, int offset, int length) throws InvalidMessageException {
ProtobufWebSocketMessage(ByteBuffer buffer) throws InvalidMessageException {
try {
this.message = SubProtocol.WebSocketMessage.parseFrom(ByteString.copyFrom(buffer, offset, length));
this.message = SubProtocol.WebSocketMessage.parseFrom(ByteString.copyFrom(buffer));
if (getType() == Type.REQUEST_MESSAGE) {
if (!message.getRequest().hasVerb() || !message.getRequest().hasPath()) {

View File

@ -9,16 +9,17 @@ import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Optional;
public class ProtobufWebSocketMessageFactory implements WebSocketMessageFactory {
@Override
public WebSocketMessage parseMessage(byte[] serialized, int offset, int len)
public WebSocketMessage parseMessage(ByteBuffer serialized)
throws InvalidMessageException
{
return new ProtobufWebSocketMessage(serialized, offset, len);
return new ProtobufWebSocketMessage(serialized);
}
@Override

View File

@ -19,17 +19,15 @@ import java.io.IOException;
import java.security.Principal;
import java.util.Optional;
import javax.security.auth.Subject;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
public class WebSocketResourceProviderFactoryTest {
@ -60,8 +58,7 @@ public class WebSocketResourceProviderFactoryTest {
when(authenticator.authenticate(eq(request))).thenThrow(new InvalidCredentialsException());
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class, REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
@ -80,16 +77,15 @@ public class WebSocketResourceProviderFactoryTest {
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
when(request.getHttpServletRequest()).thenReturn(httpServletRequest);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
assertNotNull(connection);
verifyNoMoreInteractions(response);
verify(authenticator).authenticate(eq(request));
((WebSocketResourceProvider<?>) connection).onWebSocketConnect(mock(Session.class));
((WebSocketResourceProvider<?>) connection).onWebSocketOpen(mock(Session.class));
assertNotNull(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated());
assertEquals(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated(), account);
@ -103,7 +99,6 @@ public class WebSocketResourceProviderFactoryTest {
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
mock(WebSocketConfiguration.class),
REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
@ -112,20 +107,6 @@ public class WebSocketResourceProviderFactoryTest {
verify(authenticator).authenticate(eq(request));
}
@Test
void testConfigure() {
JettyWebSocketServletFactory servletFactory = mock(JettyWebSocketServletFactory.class);
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
mock(WebSocketConfiguration.class),
REMOTE_ADDRESS_PROPERTY_NAME);
factory.configure(servletFactory);
verify(servletFactory).setCreator(eq(factory));
}
@Test
void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException {
final Account account = new Account();
@ -137,12 +118,11 @@ public class WebSocketResourceProviderFactoryTest {
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
when(request.getHttpServletRequest()).thenReturn(httpServletRequest);
final AuthenticatedWebSocketUpgradeFilter<Account> filter = mock(AuthenticatedWebSocketUpgradeFilter.class);
when(environment.getAuthenticatedWebSocketUpgradeFilter()).thenReturn(filter);
final WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
REMOTE_ADDRESS_PROPERTY_NAME);
assertNotNull(factory.createWebSocket(request, response));
verify(filter).handleAuthentication(reusableAuth, request, response);

View File

@ -47,11 +47,9 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.eclipse.jetty.websocket.api.CloseStatus;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
@ -92,11 +90,10 @@ class WebSocketResourceProviderTest {
when(session.getUpgradeRequest()).thenReturn(request);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
verify(session, never()).close(anyInt(), anyString());
verify(session, never()).close(anyInt(), anyString(), any(Callback.class));
verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
ArgumentCaptor<WebSocketSessionContext> contextArgumentCaptor = ArgumentCaptor.forClass(
WebSocketSessionContext.class);
@ -114,11 +111,9 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
ContainerResponse response = mock(ContainerResponse.class);
when(response.getStatus()).thenReturn(200);
@ -148,16 +143,15 @@ class WebSocketResourceProviderTest {
return CompletableFuture.completedFuture(response);
});
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
verify(session, never()).close(anyInt(), anyString());
verify(session, never()).close(anyInt(), anyString(), any(Callback.class));
verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar",
new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ContainerRequest> requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class);
ArgumentCaptor<ByteBuffer> responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
@ -171,7 +165,7 @@ class WebSocketResourceProviderTest {
assertThat(bundledRequest.getPath(false)).isEqualTo("bar");
verify(requestLog).log(eq("127.0.0.1"), eq(bundledRequest), eq(response));
verify(remoteEndpoint).sendBytes(responseCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom(
responseCaptor.getValue().array());
@ -191,25 +185,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(applicationHandler.apply(any(ContainerRequest.class), any(OutputStream.class))).thenReturn(
CompletableFuture.failedFuture(new IllegalStateException("foo")));
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
verify(session, never()).close(anyInt(), anyString());
verify(session, never()).close(anyInt(), anyString(), any(Callback.class));
verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar",
new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ContainerRequest> requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class);
@ -223,7 +214,7 @@ class WebSocketResourceProviderTest {
ArgumentCaptor<ByteBuffer> responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom(
responseCaptor.getValue().array());
@ -247,22 +238,20 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -287,22 +276,20 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET",
"/v1/test/doesntexist", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -327,22 +314,20 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -367,22 +352,20 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -406,22 +389,20 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -446,22 +427,20 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -486,23 +465,21 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
"/v1/test/some/testparam", List.of("Content-Type: application/json"),
Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 1001)))).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -527,23 +504,21 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
"/v1/test/some/testparam", List.of("Content-Type: application/json"),
Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 5)))).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -569,22 +544,20 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET",
"/v1/test/exception/map", List.of("Content-Type: application/json"), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@ -609,22 +582,20 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/keepalive",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
ArgumentCaptor<ByteBuffer> requestCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(requestCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(requestCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketRequestMessage requestMessage = getRequest(requestCaptor);
assertThat(requestMessage.getVerb()).isEqualTo("GET");
@ -634,11 +605,11 @@ class WebSocketResourceProviderTest {
byte[] clientResponse = new ProtobufWebSocketMessageFactory().createResponse(requestMessage.getId(), 200, "OK",
new LinkedList<>(), Optional.of("my response".getBytes())).toByteArray();
provider.onWebSocketBinary(clientResponse, 0, clientResponse.length);
provider.onWebSocketBinary(ByteBuffer.wrap(clientResponse), Callback.NOOP);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint, times(2)).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session, times(2)).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);