diff --git a/pom.xml b/pom.xml index 6d4ee2c4c..da43a9463 100644 --- a/pom.xml +++ b/pom.xml @@ -49,7 +49,7 @@ 3.48.0 1.14.1 2.21.0 - 4.0.16 + 5.0.1 4.33.2 @@ -253,12 +253,6 @@ commons-logging 1.3.6 - - org.ow2.asm - asm - 9.9.1 - test - com.stripe stripe-java @@ -361,7 +355,7 @@ org.wiremock - wiremock + wiremock-jetty12 3.13.1 test diff --git a/service/config/sample.yml b/service/config/sample.yml index 10818ef58..9ceaafc61 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -517,6 +517,7 @@ idlePrimaryDeviceReminder: grpc: port: 50051 + websocketPort: 8080 asnTable: s3Region: a-region diff --git a/service/pom.xml b/service/pom.xml index 3b3e4db42..b5a62ada7 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -23,6 +23,7 @@ 2.22.0-alpha 4.0.0 0.30.2.RELEASE + 12.1.5 @@ -137,6 +138,12 @@ io.dropwizard dropwizard-jetty + + org.eclipse.jetty.http2 + jetty-http2-client-transport + test + ${jetty.http2-client.version} + io.dropwizard dropwizard-validation @@ -242,15 +249,15 @@ org.eclipse.jetty.websocket - websocket-jetty-api + jetty-websocket-jetty-api - org.eclipse.jetty - jetty-servlets + org.eclipse.jetty.ee10 + jetty-ee10-servlets org.eclipse.jetty.websocket - websocket-jetty-client + jetty-websocket-jetty-client test diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2c6b2654d..059cbe0d7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -31,21 +31,25 @@ import io.lettuce.core.metrics.MicrometerOptions; import io.lettuce.core.resource.ClientResources; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioDatagramChannel; import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.SslContext; import io.netty.resolver.ResolvedAddressTypes; import io.netty.resolver.dns.DnsNameResolver; import io.netty.resolver.dns.DnsNameResolverBuilder; +import io.netty.util.Mapping; import jakarta.servlet.DispatcherType; -import jakarta.servlet.Filter; -import jakarta.servlet.ServletRegistration; import java.io.ByteArrayInputStream; import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.http.HttpClient; import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Duration; -import java.util.ArrayList; import java.util.Collections; import java.util.EnumSet; import java.util.List; @@ -60,9 +64,9 @@ import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.function.Function; import java.util.stream.Stream; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.websocket.core.WebSocketExtensionRegistry; import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.glassfish.jersey.server.ServerProperties; import org.signal.i18n.HeaderControlledResourceBundleLookup; import org.signal.libsignal.zkgroup.GenericServerSecretParams; @@ -135,10 +139,12 @@ import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; import org.whispersystems.textsecuregcm.currency.FixerClient; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.filters.ExternalRequestFilter; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.RestDeprecationFilter; +import org.whispersystems.textsecuregcm.filters.StripContentLengthOnConnectFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService; @@ -164,7 +170,10 @@ import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService; import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.ValidatingInterceptor; import org.whispersystems.textsecuregcm.grpc.net.ManagedGrpcServer; -import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup; +import org.whispersystems.textsecuregcm.grpc.net.ManagedEventLoopGroup; +import org.whispersystems.textsecuregcm.grpc.net.OmnibusH2Server; +import org.whispersystems.textsecuregcm.grpc.net.OmnibusRouter; +import org.whispersystems.textsecuregcm.grpc.net.SniMapper; import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer; import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; @@ -193,7 +202,7 @@ import org.whispersystems.textsecuregcm.metrics.BackupMetrics; import org.whispersystems.textsecuregcm.metrics.CallQualitySurveyManager; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener; -import org.whispersystems.textsecuregcm.metrics.MetricsHttpChannelListener; +import org.whispersystems.textsecuregcm.metrics.MetricsHttpEventHandler; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener; @@ -314,6 +323,7 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.s3.S3AsyncClient; +import javax.annotation.Nullable; public class WhisperServerService extends Application { @@ -617,10 +627,9 @@ public class WhisperServerService extends Application dnsResolutionEventLoopGroup = new ManagedEventLoopGroup<>(new NioEventLoopGroup()); + final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.getEventLoopGroup().next()) .resolvedAddressTypes(ResolvedAddressTypes.IPV6_PREFERRED) .completeOncePreferredResolved(false) .channelType(NioDatagramChannel.class) @@ -1013,24 +1022,38 @@ public class WhisperServerService extends Application serverBuilder = - NettyServerBuilder.forAddress(new InetSocketAddress(config.getGrpc().bindAddress(), config.getGrpc().port())); + final ManagedEventLoopGroup omnibusLocalEventLoopGroup = new ManagedEventLoopGroup<>(new DefaultEventLoopGroup()); + final ManagedEventLoopGroup omnibusNioEventLoopGroup = new ManagedEventLoopGroup<>(new NioEventLoopGroup()); + final LocalAddress grpcLocalAddress = new LocalAddress("grpc"); + final ServerBuilder serverBuilder = NettyServerBuilder + .forAddress(grpcLocalAddress) + .channelType(LocalServerChannel.class) + .bossEventLoopGroup(omnibusLocalEventLoopGroup.getEventLoopGroup()) + .workerEventLoopGroup(omnibusLocalEventLoopGroup.getEventLoopGroup()); authenticatedServices.forEach(serverBuilder::addService); unauthenticatedServices.forEach(serverBuilder::addService); - final ManagedGrpcServer exposedGrpcServer = new ManagedGrpcServer(serverBuilder.build()); + final ManagedGrpcServer localGrpcServer = new ManagedGrpcServer(serverBuilder.build()); - environment.lifecycle().manage(exposedGrpcServer); + final SocketAddress websocketAddress = + new InetSocketAddress(config.getGrpc().websocketAddress(), config.getGrpc().websocketPort()); + final OmnibusRouter omnibusRouter = new OmnibusRouter(List.of( + new OmnibusRouter.OmnibusRoute("/v1/websocket", websocketAddress), + new OmnibusRouter.OmnibusRoute("/v1/provisioning", websocketAddress)), + grpcLocalAddress); + @Nullable final Mapping sniMapping = config.getGrpc().h2c() + ? null + : SniMapper.buildSniMapping(config.getTlsKeyStoreConfiguration().path(), config.getTlsKeyStoreConfiguration().password().value()); + final OmnibusH2Server omnibusH2Server = new OmnibusH2Server( + sniMapping, + omnibusNioEventLoopGroup.getEventLoopGroup(), + omnibusLocalEventLoopGroup.getEventLoopGroup(), + new InetSocketAddress(config.getGrpc().bindAddress(), config.getGrpc().port()), omnibusRouter, + config.getGrpc().idleTimeout()); - final List filters = new ArrayList<>(); - filters.add(remoteDeprecationFilter); - filters.add(new RemoteAddressFilter()); - filters.add(new TimestampResponseFilter()); - - for (Filter filter : filters) { - environment.servlets() - .addFilter(filter.getClass().getSimpleName(), filter) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); - } + environment.lifecycle().manage(omnibusLocalEventLoopGroup); + environment.lifecycle().manage(omnibusNioEventLoopGroup); + environment.lifecycle().manage(localGrpcServer); + environment.lifecycle().manage(omnibusH2Server); if (!config.getExternalRequestFilterConfiguration().paths().isEmpty()) { environment.servlets().addFilter(ExternalRequestFilter.class.getSimpleName(), @@ -1048,9 +1071,7 @@ public class WhisperServerService extends Application { - final WebSocketExtensionRegistry extensionRegistry = WebSocketServerComponents - .getWebSocketComponents(environment.getApplicationContext().getServletContext()) - .getExtensionRegistry(); - if (config.getWebSocketConfiguration().isDisablePerMessageDeflate()) { - extensionRegistry.unregister("permessage-deflate"); - } else if (config.getWebSocketConfiguration().isDisableCrossMessageOutgoingCompression()) { - extensionRegistry.unregister("permessage-deflate"); - extensionRegistry.register("permessage-deflate", NoContextTakeoverPerMessageDeflateExtension.class); - } - }); - WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, AuthenticatedDevice.class, config.getWebSocketConfiguration(), - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); + webSocketEnvironment, AuthenticatedDevice.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory<>( - provisioningEnvironment, AuthenticatedDevice.class, config.getWebSocketConfiguration(), - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); + provisioningEnvironment, AuthenticatedDevice.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); - ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet); - ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), + (servletContext, container) -> { + container.addMapping(websocketServletPath, webSocketServlet); + container.addMapping(provisioningWebsocketServletPath, provisioningServlet); - websocket.addMapping(websocketServletPath); - websocket.setAsyncSupported(true); + PriorityFilter.ensureFilter(servletContext, new StripContentLengthOnConnectFilter()); + PriorityFilter.ensureFilter(servletContext, new TimestampResponseFilter()); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + PriorityFilter.ensureFilter(servletContext, remoteDeprecationFilter); - provisioning.addMapping(provisioningWebsocketServletPath); - provisioning.setAsyncSupported(true); + container.setMaxBinaryMessageSize(config.getWebSocketConfiguration().getMaxBinaryMessageSize()); + container.setMaxTextMessageSize(config.getWebSocketConfiguration().getMaxTextMessageSize()); + + final WebSocketExtensionRegistry extensionRegistry = WebSocketServerComponents + .getWebSocketComponents(environment.getApplicationContext()) + .getExtensionRegistry(); + if (config.getWebSocketConfiguration().isDisablePerMessageDeflate()) { + extensionRegistry.unregister("permessage-deflate"); + } else if (config.getWebSocketConfiguration().isDisableCrossMessageOutgoingCompression()) { + extensionRegistry.unregister("permessage-deflate"); + extensionRegistry.register("permessage-deflate", NoContextTakeoverPerMessageDeflateExtension.class); + } + }); environment.admin().addTask(new SetRequestLoggingEnabledTask()); - } private void registerExceptionMappers(Environment environment, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java index 38ff7d8ef..e2853497c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java @@ -10,10 +10,9 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.Optional; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/GrpcConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/GrpcConfiguration.java index fb3c08742..295879b8f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/GrpcConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/GrpcConfiguration.java @@ -5,11 +5,33 @@ package org.whispersystems.textsecuregcm.configuration; import jakarta.validation.constraints.NotNull; +import java.time.Duration; + +/// Configuration for the gRPC Server +/// +/// @param bindAddress The host to bind the omnibus server to +/// @param port The port to bind the omnibus server to +/// @param websocketAddress The address of a listening websocket server for handling legacy requests +/// @param websocketPort The port of a listening websocket server for handling legacy requests +/// @param idleTimeout The duration after which an idle connection may be disconnected +/// @param h2c If true, listen for plaintext h2c with prior-knowledge +public record GrpcConfiguration( + @NotNull String bindAddress, + @NotNull Integer port, + @NotNull String websocketAddress, + @NotNull Integer websocketPort, + @NotNull Duration idleTimeout, + boolean h2c) { -public record GrpcConfiguration(@NotNull String bindAddress, @NotNull Integer port) { public GrpcConfiguration { if (bindAddress == null || bindAddress.isEmpty()) { bindAddress = "localhost"; } + if (websocketAddress == null || websocketAddress.isEmpty()) { + websocketAddress = "localhost"; + } + if (idleTimeout == null) { + idleTimeout = Duration.ofMinutes(5); + } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/TlsKeyStoreConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/TlsKeyStoreConfiguration.java index 1bd2c1d68..c38912de8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/TlsKeyStoreConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/TlsKeyStoreConfiguration.java @@ -7,6 +7,9 @@ package org.whispersystems.textsecuregcm.configuration; import jakarta.validation.constraints.NotNull; import org.whispersystems.textsecuregcm.configuration.secrets.SecretString; +import javax.annotation.Nullable; -public record TlsKeyStoreConfiguration(@NotNull SecretString password) { +public record TlsKeyStoreConfiguration( + @Nullable String path, + @NotNull SecretString password) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/PriorityFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/PriorityFilter.java new file mode 100644 index 000000000..03070cf00 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/PriorityFilter.java @@ -0,0 +1,82 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.filters; + +import jakarta.servlet.DispatcherType; +import jakarta.servlet.Filter; +import jakarta.servlet.ServletContext; +import java.util.EnumSet; +import java.util.Objects; +import org.eclipse.jetty.ee10.servlet.FilterHolder; +import org.eclipse.jetty.ee10.servlet.FilterMapping; +import org.eclipse.jetty.ee10.servlet.ServletContextHandler; +import org.eclipse.jetty.ee10.servlet.ServletHandler; +import org.eclipse.jetty.server.handler.ContextHandler; +import org.eclipse.jetty.util.component.LifeCycle; + +public class PriorityFilter { + + private PriorityFilter() {} + + private static FilterHolder getFilter(ServletContext servletContext, final Class filterClass) { + final ContextHandler contextHandler = Objects.requireNonNull(ServletContextHandler.getServletContextHandler(servletContext)); + final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class); + return servletHandler.getFilter(filterClass.getName()); + } + + /** + * Ensure a filter is available on the provided ServletContext, a new filter will added if one does not already + * exist. + *

+ * If a new filter is added, it will be added before all other filters. + *

+ * Modeled after {@link org.eclipse.jetty.ee10.websocket.servlet.WebSocketUpgradeFilter#ensureFilter(ServletContext)}, + * since its use of {@link org.eclipse.jetty.ee10.servlet.ServletHandler#prependFilter(FilterHolder)} is what makes + * this necessary. + */ + public static void ensureFilter(final ServletContext servletContext, final Filter filter) { + FilterHolder existingFilter = getFilter(servletContext, filter.getClass()); + if (existingFilter != null) { + return; + } + + final ContextHandler contextHandler = ServletContextHandler.getServletContextHandler(servletContext); + final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class); + + final String pathSpec = "/*"; + final FilterHolder holder = new FilterHolder(filter); + holder.setName(filter.getClass().getName()); + holder.setAsyncSupported(true); + + final FilterMapping mapping = new FilterMapping(); + mapping.setFilterName(holder.getName()); + mapping.setPathSpec(pathSpec); + mapping.setDispatcherTypes(EnumSet.of(DispatcherType.REQUEST)); + + // Add as the first filter in the list. + servletHandler.prependFilter(holder); + servletHandler.prependFilterMapping(mapping); + + // If we create the filter we must also make sure it is removed if the context is stopped. + contextHandler.addEventListener(new LifeCycle.Listener() + { + @Override + public void lifeCycleStopping(LifeCycle event) + { + servletHandler.removeFilterHolder(holder); + servletHandler.removeFilterMapping(mapping); + contextHandler.removeEventListener(this); + } + + @Override + public String toString() + { + return String.format("%sCleanupListener", filter.getClass().getSimpleName()); + } + }); + + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java index c1ea020b6..3b8090e66 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java @@ -26,10 +26,6 @@ public class RemoteAddressFilter implements Filter { public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress"; private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class); - - public RemoteAddressFilter() { - } - @Override public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws ServletException, IOException { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/StripContentLengthOnConnectFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/StripContentLengthOnConnectFilter.java new file mode 100644 index 000000000..9265f8c46 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/StripContentLengthOnConnectFilter.java @@ -0,0 +1,71 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.filters; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import java.io.IOException; +import java.nio.ByteBuffer; +import org.eclipse.jetty.ee10.servlet.ServletContextRequest; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.http.HttpHeader; +import org.eclipse.jetty.http.HttpMethod; +import org.eclipse.jetty.http.HttpVersion; +import org.eclipse.jetty.http.MetaData; +import org.eclipse.jetty.server.HttpStream; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.util.Callback; + +/// Our current version of jetty (12.1.5) has a bug where it includes content-length:0 on +/// CONNECT websocket upgrade requests. Providing an HTTP/2 header frame with a +/// content-length that does not match the sum of the lengths of the data frames is technically +/// a malformed HTTP/2 stream and our netty-based reverse proxy implementation rejects it. This +/// filter strips out the superfluous content-length at stream-send time. It can be removed once +/// we update to a jetty version that fixes [jetty/jetty.project#15074](https://github.com/jetty/jetty.project/issues/15074) +public class StripContentLengthOnConnectFilter implements Filter { + + @Override + public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) + throws IOException, ServletException { + if (request instanceof HttpServletRequest hsr && + HttpVersion.HTTP_2.is(hsr.getProtocol()) && + HttpMethod.CONNECT.is(hsr.getMethod())) { + final Request coreRequest = ServletContextRequest.getServletContextRequest(hsr); + if (coreRequest != null) { + coreRequest.addHttpStreamWrapper(StripContentLengthStream::new); + } + } + chain.doFilter(request, response); + } + + private static class StripContentLengthStream extends HttpStream.Wrapper { + + StripContentLengthStream(final HttpStream wrapped) { + super(wrapped); + } + + @Override + public void send(MetaData.Request request, MetaData.Response response, boolean last, ByteBuffer content, + Callback callback) { + if (response != null && response.getStatus() == 200 && response.getHttpFields() + .contains(HttpHeader.CONTENT_LENGTH)) { + final HttpFields fieldsWithoutContentLengthHeader = + HttpFields.build(response.getHttpFields()).remove(HttpHeader.CONTENT_LENGTH); + response = new MetaData.Response( + response.getStatus(), + response.getReason(), + response.getHttpVersion(), + fieldsWithoutContentLengthHeader, + -1, + response.getTrailersSupplier()); + } + super.send(request, response, last, content, callback); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java index eeeab6684..45fd52943 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java @@ -48,6 +48,8 @@ public class RequestAttributesInterceptor implements ServerInterceptor { final String acceptLanguageHeader = headers.get(ACCEPT_LANG_KEY); final String xForwardedForHeader = headers.get(X_FORWARDED_FOR_KEY); + // This assumes that X-Forwarded-For has been set by a trusted intermediate proxy. For example, this may be set by + // OmnibusH2Server which itself sets X-Forwarded-For using a PPv2 header that comes from a trusted load-balancer. final Optional remoteAddress = getMostRecentProxy(xForwardedForHeader) .flatMap(mostRecentProxy -> { try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/H2FrameProxyHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/H2FrameProxyHandler.java new file mode 100644 index 000000000..e4167a4d4 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/H2FrameProxyHandler.java @@ -0,0 +1,91 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.micrometer.core.instrument.Metrics; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http2.Http2StreamFrame; +import io.netty.util.ReferenceCountUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; + +/// Writes all inbound H2 frames to [this#peerStream], renumbering the inbound H2 stream-id for peer H2 stream +public class H2FrameProxyHandler extends ChannelInboundHandlerAdapter { + + private static final Logger logger = LoggerFactory.getLogger(H2FrameProxyHandler.class); + private static final String WRITABILITY_CHANGED_COUNTER_NAME = MetricsUtil.name(H2FrameProxyHandler.class, "writabilityChanged"); + + private final Channel peerStream; + private final String proxyNameTag; + + // If we fail to write to the peerStream, we want to close the inbound channel. Rather than allocate a new listener + // that captures the inbound ChannelHandlerContext on every message, we capture the ChannelHandlerContext in + // handlerAdded and use it on all forwarded writes. This would not work if we attached this handler to more than + // one channel, but we already have a designated peerStream so this handler is fundamentally single-channel. + private ChannelFutureListener closeInboundOnPeerFailure = null; + + public H2FrameProxyHandler(final Channel peerStream, final String proxyNameTag) { + this.peerStream = peerStream; + this.proxyNameTag = proxyNameTag; + } + + @Override + public void handlerAdded(final ChannelHandlerContext ctx) { + closeInboundOnPeerFailure = f -> { + if (!f.isSuccess()) { + ctx.close(); + } + }; + + // When the peer stream we are forwarding to becomes unwritable/writable, stop/start reading from the source stream. + // This prevents us from reading from the source stream as fast as we can just to buffer requests for the peer + // stream. + peerStream.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelWritabilityChanged(final ChannelHandlerContext peerCtx) throws Exception { + Metrics.counter(WRITABILITY_CHANGED_COUNTER_NAME, + "isWritable", Boolean.toString(peerCtx.channel().isWritable()), + "proxy", proxyNameTag) + .increment(); + ctx.channel().config().setAutoRead(peerStream.isWritable()); + super.channelWritabilityChanged(peerCtx); + } + }); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + logger.trace("Received frame {}", msg); + if (!(msg instanceof Http2StreamFrame streamFrame)) { + logger.error("Received unexpected frame {}", msg); + ReferenceCountUtil.release(msg); + ctx.close(); + return; + } + + // Clear the stream-id on this frame, so netty will associate it with the peerStream's stream-id. The inbound + // frame has a stream-id associated with the inbound connection. This will not match the stream-id of the peer + // stream we are forwarding the frames to. If the stream-id on a frame is not set, netty handles sending the + // stream-id on the frame to the target stream's stream-id. + streamFrame.stream(null); + + peerStream.writeAndFlush(streamFrame).addListener(closeInboundOnPeerFailure); + } + + @Override + public void channelInactive(final ChannelHandlerContext ctx) { + peerStream.close(); + } + + @Override + public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) { + logger.warn("Exception proxying frames", cause); + ctx.close(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedEventLoopGroup.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedEventLoopGroup.java new file mode 100644 index 000000000..5c77b5df3 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedEventLoopGroup.java @@ -0,0 +1,30 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.dropwizard.lifecycle.Managed; +import io.netty.channel.EventLoopGroup; + +/** + * A wrapper for a Netty {@link EventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing + * Dropwizard to manage the lifecycle of the event loop group. + */ +public class ManagedEventLoopGroup implements Managed { + + private final T eventLoopGroup; + + public ManagedEventLoopGroup(final T eventLoopGroup) { + this.eventLoopGroup = eventLoopGroup; + } + + @Override + public void stop() throws Exception { + this.eventLoopGroup.shutdownGracefully().await(); + } + + public T getEventLoopGroup() { + return eventLoopGroup; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java deleted file mode 100644 index 06d3e97db..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.dropwizard.lifecycle.Managed; -import io.netty.channel.nio.NioEventLoopGroup; - -/** - * A wrapper for a Netty {@link NioEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing - * Dropwizard to manage the lifecycle of the event loop group. - */ -public class ManagedNioEventLoopGroup extends NioEventLoopGroup implements Managed { - - @Override - public void stop() throws Exception { - this.shutdownGracefully().await(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusConnectionCounterHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusConnectionCounterHandler.java new file mode 100644 index 000000000..ef4199a93 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusConnectionCounterHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import java.util.concurrent.atomic.AtomicLong; + +@ChannelHandler.Sharable +public class OmnibusConnectionCounterHandler extends ChannelInboundHandlerAdapter { + private final AtomicLong openConnections; + private final Counter acceptedConnectionsCounter = + Metrics.counter(MetricsUtil.name(OmnibusConnectionCounterHandler.class, "connectionsAccepted")); + + public OmnibusConnectionCounterHandler() { + openConnections = + Metrics.gauge(MetricsUtil.name(OmnibusConnectionCounterHandler.class, "openConnections"), new AtomicLong()); + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + acceptedConnectionsCounter.increment(); + openConnections.incrementAndGet(); + super.channelRegistered(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + openConnections.decrementAndGet(); + super.channelInactive(ctx); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusExceptionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusExceptionHandler.java new file mode 100644 index 000000000..325b52c9b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusExceptionHandler.java @@ -0,0 +1,55 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.micrometer.core.instrument.Metrics; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; + +/// A handler that closes the channel on an exception and records errors in a counter. This should be placed at the tail +/// of pipelines to catch uncaught exceptions gracefully +@ChannelHandler.Sharable +public class OmnibusExceptionHandler extends ChannelInboundHandlerAdapter { + + private static final Logger logger = LoggerFactory.getLogger(OmnibusExceptionHandler.class); + + private static final String UNCAUGHT_EXCEPTION_COUNTER_NAME = MetricsUtil.name(OmnibusExceptionHandler.class, + "uncaughtException"); + + private final String channelName; + private final List> expectedExceptions; + + public OmnibusExceptionHandler(final String channelName, final List> expectedExceptions) { + this.channelName = channelName; + this.expectedExceptions = expectedExceptions; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + Metrics.counter(UNCAUGHT_EXCEPTION_COUNTER_NAME, + "channelName", channelName, + "exceptionClass", cause.getClass().getSimpleName()) + .increment(); + + // There are 'expected' ways to get exceptions on a channel (e.g. client disconnects) so we only log them at debug. + if (expectedException(cause)) { + logger.debug("uncaught exception on channel {}", channelName, cause); + } else { + logger.warn("unexpected uncaught exception on channel {}", channelName, cause); + } + ctx.close(); + } + + private boolean expectedException(final Throwable exception) { + return expectedExceptions + .stream() + .anyMatch(expectedException -> expectedException.isInstance(exception)); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2Server.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2Server.java new file mode 100644 index 000000000..7bfa7ef95 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2Server.java @@ -0,0 +1,176 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.google.common.annotations.VisibleForTesting; +import io.dropwizard.lifecycle.Managed; +import io.micrometer.core.instrument.Metrics; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufAllocatorMetricProvider; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.SimpleUserEventChannelHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2FrameCodecBuilder; +import io.netty.handler.codec.http2.Http2MultiplexHandler; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.handler.codec.http2.Http2StreamChannel; +import io.netty.handler.ssl.ApplicationProtocolNames; +import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; +import io.netty.handler.ssl.SniHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.timeout.IdleStateEvent; +import io.netty.handler.timeout.IdleStateHandler; +import io.netty.util.Mapping; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.SocketException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; + +/// An HTTP/2 server that proxies H2 streams to configurable backends via path-based routing +public class OmnibusH2Server implements Managed { + + private static final Logger logger = LoggerFactory.getLogger(OmnibusH2Server.class); + + private static final OmnibusConnectionCounterHandler CONNECT_COUNTER = new OmnibusConnectionCounterHandler(); + private static final OmnibusExceptionHandler HANDSHAKE_EXCEPTION_HANDLER = + new OmnibusExceptionHandler("omnibus-handshake", List.of(SocketException.class, DecoderException.class, IOException.class)); + private static final OmnibusExceptionHandler SESSION_EXCEPTION_HANDLER = + new OmnibusExceptionHandler("omnibus-session", List.of(Http2Exception.class)); + private static final String IDLE_DISCONNECT_COUNTER_NAME = MetricsUtil.name(OmnibusH2Server.class, "idleDisconnect"); + + private final @Nullable Mapping sslContextBySni; + private final OmnibusRouter router; + private final Duration idleTimeout; + private final DefaultEventLoopGroup localEventLoopGroup; + private final NioEventLoopGroup nioEventLoopGroup; + private final SocketAddress bindAddress; + + private Channel serverChannel; + + /// Create an omnibus server + /// + /// @param sslContextBySni If not null, a mapping between domain (SNI) and the appropriate SslContext to use for + /// that SNI. If null, the server will not include TLS (h2c with prior-knowledge) + /// @param nioEventLoopGroup Event loop to use for all NIO channel pipelines + /// @param localEventLoopGroup Event loop to use for all local channel pipelines + /// @param bindAddress The address the server should listen on + /// @param router How the server should select backends based on request paths + public OmnibusH2Server( + final @Nullable Mapping sslContextBySni, + final NioEventLoopGroup nioEventLoopGroup, + final DefaultEventLoopGroup localEventLoopGroup, + final SocketAddress bindAddress, + final OmnibusRouter router, + final Duration idleTimeout) { + this.sslContextBySni = sslContextBySni; + this.nioEventLoopGroup = nioEventLoopGroup; + this.localEventLoopGroup = localEventLoopGroup; + this.bindAddress = bindAddress; + this.router = router; + this.idleTimeout = idleTimeout; + } + + @Override + public void start() throws Exception { + if (this.sslContextBySni == null) { + logger.warn("No SSL configuration provided for OmnibusH2Server, serving h2c"); + } + + if (ByteBufAllocator.DEFAULT instanceof ByteBufAllocatorMetricProvider alloc) { + Metrics.gauge(MetricsUtil.name(OmnibusH2Server.class, "nettyUsedDirectMemory"), + alloc, + allocator -> allocator.metric().usedDirectMemory()); + Metrics.gauge(MetricsUtil.name(OmnibusH2Server.class, "nettyUsedHeapMemory"), + alloc, + allocator -> allocator.metric().usedHeapMemory()); + } + + final ServerBootstrap bootstrap = new ServerBootstrap() + .group(nioEventLoopGroup) + .channel(NioServerSocketChannel.class) + .childOption(ChannelOption.SO_KEEPALIVE, true) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(final SocketChannel ch) { + ch.pipeline().addLast(new IdleStateHandler(0, 0, idleTimeout.toMillis(), TimeUnit.MILLISECONDS)); + ch.pipeline().addLast(new SimpleUserEventChannelHandler() { + @Override + protected void eventReceived(final ChannelHandlerContext ctx, final IdleStateEvent evt) { + Metrics.counter(IDLE_DISCONNECT_COUNTER_NAME, "type", evt.state().name()).increment(); + ctx.close(); + } + }); + ch.pipeline().addLast(CONNECT_COUNTER); + ch.pipeline().addLast(new ProxyProtocolHandler()); + ch.pipeline().addLast(new ProxyMessageAttributeSetterHandler()); + if (sslContextBySni == null) { + configureH2Pipeline(ch.pipeline()); + } else { + ch.pipeline().addLast(new SniHandler(sslContextBySni)); + ch.pipeline().addLast(new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_2) { + @Override + protected void configurePipeline(final ChannelHandlerContext ctx, final String protocol) { + if (!ApplicationProtocolNames.HTTP_2.equals(protocol)) { + // HTTP/2 should be enforced by our ALPN settings + logger.error("Unsupported protocol negotiated: {}, closing connection", protocol); + ctx.close(); + return; + } + configureH2Pipeline(ctx.pipeline()); + } + }); + ch.pipeline().addLast(HANDSHAKE_EXCEPTION_HANDLER); + } + } + }); + + serverChannel = bootstrap.bind(bindAddress).sync().channel(); + logger.info("Omnibus server listening on {}", getLocalAddress()); + } + + @VisibleForTesting + InetSocketAddress getLocalAddress() { + return (InetSocketAddress) serverChannel.localAddress(); + } + + @Override + public void stop() { + if (serverChannel != null) { + logger.info("Stopping omnibus server"); + serverChannel.close().syncUninterruptibly(); + logger.info("Omnibus server stopped"); + } + } + + private void configureH2Pipeline(final ChannelPipeline pipeline) { + // Advertise support for RFC-8441 extended connect + final Http2Settings settings = Http2Settings.defaultSettings().connectProtocolEnabled(true); + pipeline.addLast(Http2FrameCodecBuilder.forServer().initialSettings(settings).build()); + pipeline.addLast(new Http2MultiplexHandler(new ChannelInitializer() { + @Override + protected void initChannel(final Http2StreamChannel ch) { + ch.pipeline().addLast(new OmnibusH2StreamHandler(nioEventLoopGroup, localEventLoopGroup, router)); + } + })); + pipeline.addLast(SESSION_EXCEPTION_HANDLER); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2StreamHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2StreamHandler.java new file mode 100644 index 000000000..81c4cc53e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2StreamHandler.java @@ -0,0 +1,230 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.DefaultHttp2HeadersFrame; +import io.netty.handler.codec.http2.DefaultHttp2ResetFrame; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2FrameCodecBuilder; +import io.netty.handler.codec.http2.Http2HeadersFrame; +import io.netty.handler.codec.http2.Http2MultiplexHandler; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.handler.codec.http2.Http2StreamChannel; +import io.netty.handler.codec.http2.Http2StreamChannelBootstrap; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/// Handler added on each newly created [Http2StreamChannel] on an H2 connection. Inspects the [Http2HeadersFrame] and +/// determines which backend to forward the stream to, and then proxies frames to and from the backend. +/// +/// When this handler receives an H2 header for a new stream-id=X on our parent H2 connection it will +/// - Receive the stream-id=X header and check the path to determine the correct backend +/// - Make a new H2 connection to the backend +/// - Forward the header with stream-id=Y to the backend +/// - Install a [H2FrameProxyHandler] on the backend stream pipeline that forwards the received stream-id=Y frames from +/// the backend back to the client on stream-id=X +/// - Install a [H2FrameProxyHandler] on the client stream pipeline forwards the received stream-id=X frames from the +/// client to the backend on stream-id=Y +public class OmnibusH2StreamHandler extends ChannelInboundHandlerAdapter { + + private static final Logger logger = LoggerFactory.getLogger(OmnibusH2StreamHandler.class); + + private static final OmnibusExceptionHandler BACKEND_CONNECTION_EXCEPTION_HANDLER = + new OmnibusExceptionHandler("backend-connection", List.of()); + private static final String BACKEND_STREAM_COUNTER_NAME = name(OmnibusH2StreamHandler.class, "backendStream"); + private static final String BACKEND_CONNECT_DURATION_NAME = name(OmnibusH2StreamHandler.class, + "backendConnectDuration"); + private static final String BACKEND_TAG = "backend"; + + private final OmnibusRouter router; + + private final DefaultEventLoopGroup localEventLoopGroup; + private final NioEventLoopGroup nioEventLoopGroup; + + public OmnibusH2StreamHandler( + final NioEventLoopGroup nioEventLoopGroup, + final DefaultEventLoopGroup localEventLoopGroup, + final OmnibusRouter router) { + this.router = router; + this.localEventLoopGroup = localEventLoopGroup; + this.nioEventLoopGroup = nioEventLoopGroup; + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + if (!(msg instanceof Http2HeadersFrame headersFrame)) { + logger.warn("Expected initial HEADERS frame but got {}", msg.getClass().getSimpleName()); + ReferenceCountUtil.release(msg); + ctx.close(); + return; + } + + // We don't expect headers frames to come with manually managed memory attached. Assert this in case this changes + // in the future, but for now we don't have to worry about freeing the headers frame + assert !(headersFrame instanceof ReferenceCounted); + + // Disable reading from the client because we want to wait until we make the backend connection and install the + // forwarding handler before processing any more frames. + ctx.channel().config().setAutoRead(false); + + // Select the target backend based on the path + final String path = Optional.ofNullable(headersFrame.headers().path()).map(CharSequence::toString).orElse(""); + final SocketAddress target = router.match(path); + final String backendTag = target.toString(); + + Metrics.counter(BACKEND_STREAM_COUNTER_NAME, BACKEND_TAG, backendTag).increment(); + + // Set X-Forwarded-For from the PROXY protocol header if present, otherwise via the remote address + final InetAddress proxyRemoteAddress = ctx.channel().parent() + .attr(ProxyMessageAttributeSetterHandler.PROXY_REMOTE_ADDRESS) + .get(); + headersFrame.headers().set("x-forwarded-for", proxyRemoteAddress != null + ? proxyRemoteAddress.getHostAddress() + : ((InetSocketAddress) ctx.channel().remoteAddress()).getHostString()); + + // Make a new H2 connection to the target backend + final Timer.Sample connectSample = Timer.start(); + new Bootstrap() + .group(selectEventLoop(ctx, target)) + .channel(target instanceof LocalAddress ? LocalChannel.class : NioSocketChannel.class) + .handler(new ChannelInitializer<>() { + @Override + protected void initChannel(final Channel ch) { + ch.pipeline() + .addLast(Http2FrameCodecBuilder.forClient().initialSettings(Http2Settings.defaultSettings()).build()); + // Http2MultiplexHandler takes handler that is added to new inbound streams. A client Http2MultiplexHandler + // like we're defining here should never receive an inbound H2 stream so we can just pass a noop handler + ch.pipeline().addLast(new Http2MultiplexHandler(new NoopInboundStreamHandler())); + ch.pipeline().addLast(BACKEND_CONNECTION_EXCEPTION_HANDLER); + } + }) + .connect(target) + .addListener((ChannelFuture connectFuture) -> { + connectSample.stop(Timer.builder(BACKEND_CONNECT_DURATION_NAME) + .tag(BACKEND_TAG, backendTag) + .tag("outcome", connectFuture.isSuccess() ? "success" : "failure") + .register(Metrics.globalRegistry)); + + if (!connectFuture.isSuccess()) { + // Close the client stream with a 502: Bad Gateway if the backend wasn't available + logger.warn("Failed to connect to backend {}", target, connectFuture.cause()); + ctx.channel() + .writeAndFlush(new DefaultHttp2HeadersFrame( + new DefaultHttp2Headers().status("502"), true)) + .addListener(ChannelFutureListener.CLOSE); + return; + } + + // Connected, open a new H2 stream to the backend so we can proxy the client's frames + logger.trace("Opening a HTTP/2 stream to the backend {}", target); + final Channel backendConnection = connectFuture.channel(); + createBackendProxyStream(ctx, backendConnection, headersFrame); + }); + } + + /// Create a proxy stream on the provided `backendConnection` that forwards H2 frames to/from the client H2 stream. + /// + /// @param clientStreamCtx The context for a client H2 stream that targets the backend + /// @param backendConnection An established H2 connection [Channel], on which a new h2 stream will be opened + /// @param headersFrame The first `headersFrame` from the client h2 stream that should be forwarded to the new + /// backend stream + private void createBackendProxyStream( + final ChannelHandlerContext clientStreamCtx, + final Channel backendConnection, + final Http2HeadersFrame headersFrame) { + new Http2StreamChannelBootstrap(backendConnection) + // Forwards response frames from the backend back to the client stream + .handler(new H2FrameProxyHandler(clientStreamCtx.channel(), "responseStream")) + .open() + .addListener((io.netty.util.concurrent.Future streamFuture) -> { + if (!streamFuture.isSuccess()) { + logger.warn("Failed to open backend stream", streamFuture.cause()); + clientStreamCtx.channel() + .writeAndFlush(new DefaultHttp2ResetFrame(Http2Error.INTERNAL_ERROR)) + .addListener(ChannelFutureListener.CLOSE); + backendConnection.close(); + return; + } + + final Http2StreamChannel backendStream = streamFuture.getNow(); + + // Close the entire H2 connection whenever the stream we just opened closes. We only plan on using + // a single stream on this connection. + backendStream.closeFuture().addListener(_ -> backendConnection.close()); + + // We're going to modify the inbound H2 stream channel, which runs on a different eventloop than the + // outbound channel we've made to the backend. We have to submit our updates back to the inbound + // channel's event loop for thread safety + clientStreamCtx.channel().eventLoop().execute(() -> { + + if (!clientStreamCtx.channel().isActive()) { + // The client disconnected already and the client pipeline is already torn down. + backendConnection.close(); + return; + } + + // Install proxy on client stream, remove this handler, then fire the buffered headers through the proxy + clientStreamCtx.pipeline().replace( + OmnibusH2StreamHandler.this, + "backend-to-client-proxy", + new H2FrameProxyHandler(backendStream, "requestStream")); + clientStreamCtx.channel().pipeline().fireChannelRead(headersFrame); + + // Resume inbound reads, which should now be forwarded + clientStreamCtx.channel().config().setAutoRead(true); + }); + }); + } + + private EventLoopGroup selectEventLoop(final ChannelHandlerContext inboundCtx, SocketAddress target) { + final boolean localInbound = inboundCtx.channel() instanceof LocalChannel; + final boolean localTarget = target instanceof LocalAddress; + + // If the inbound eventloop matches the target type, we can just reuse the inbound's event loop + if (localInbound == localTarget) { + return inboundCtx.channel().eventLoop(); + } + + return localTarget ? this.localEventLoopGroup : this.nioEventLoopGroup; + } + + private static class NoopInboundStreamHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRegistered(final ChannelHandlerContext ctx) { + logger.error("Inbound stream handler was registered when no inbound streams expected"); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + logger.error("Received unexpected message: {} on inbound stream from backend", msg); + super.channelRead(ctx, msg); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusRouter.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusRouter.java new file mode 100644 index 000000000..1365a0c2b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusRouter.java @@ -0,0 +1,30 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import java.net.SocketAddress; +import java.util.List; + +public class OmnibusRouter { + + public record OmnibusRoute(String prefix, SocketAddress backend) {} + + private final List prefixRoutes; + private final SocketAddress defaultBackend; + + public OmnibusRouter(final List prefixRoutes, final SocketAddress defaultBackend) { + this.prefixRoutes = prefixRoutes; + this.defaultBackend = defaultBackend; + } + + SocketAddress match(final String path) { + for (final OmnibusRoute route : prefixRoutes) { + if (path.startsWith(route.prefix)) { + return route.backend; + } + } + return defaultBackend; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyMessageAttributeSetterHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyMessageAttributeSetterHandler.java new file mode 100644 index 000000000..97d13e692 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyMessageAttributeSetterHandler.java @@ -0,0 +1,42 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.util.AttributeKey; +import java.net.InetAddress; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/// Reads the decoded [HAProxyMessage], stores the source address as a channel attribute, and removes itself. +class ProxyMessageAttributeSetterHandler extends ChannelInboundHandlerAdapter { + private static final Logger logger = LoggerFactory.getLogger(ProxyMessageAttributeSetterHandler.class); + + /// Attribute for the remote address extracted from a proxy protocol header + static final AttributeKey PROXY_REMOTE_ADDRESS = AttributeKey.newInstance("proxyRemoteAddress"); + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { + if (!(msg instanceof HAProxyMessage proxyMessage)) { + ctx.pipeline().remove(this); + ctx.fireChannelRead(msg); + return; + } + + try { + final String sourceAddress = proxyMessage.sourceAddress(); + if (sourceAddress != null) { + ctx.channel().attr(PROXY_REMOTE_ADDRESS).set(InetAddress.getByName(sourceAddress)); + } else { + logger.warn("PROXY protocol message has no source address"); + } + } finally { + proxyMessage.release(); + ctx.pipeline().remove(this); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolHandler.java new file mode 100644 index 000000000..7b4b3b3c3 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolHandler.java @@ -0,0 +1,48 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +import io.micrometer.core.instrument.Metrics; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.ProtocolDetectionResult; +import io.netty.handler.codec.haproxy.HAProxyMessageDecoder; +import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; +import java.util.List; + +class ProxyProtocolHandler extends ByteToMessageDecoder { + + private static final String PROXY_PROTOCOL_DETECTED_NAME = + name(ProxyProtocolHandler.class, "proxyProtocol"); + + @Override + protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List out) { + + // This does not advance the read index, so the bytes we accumulate via ByteToMessageDecoder are always forwarded + // once we get enough (either to an HAProxyMessageDecoder, or just to the rest of the pipeline) + final ProtocolDetectionResult detected = HAProxyMessageDecoder.detectProtocol(in); + + switch (detected.state()) { + case NEEDS_MORE_DATA: + break; + case DETECTED: + // There is a valid proxy-protocol header. Replace ourselves with the actual decoder (which will forward our + // accumulated bytes to the decoder via handlerRemoved) + Metrics.counter(PROXY_PROTOCOL_DETECTED_NAME, "detected", "true").increment(); + ctx.pipeline().replace(this, "haproxy-decoder", new HAProxyMessageDecoder()); + break; + case INVALID: + // No header, we can just forward any bytes we've accumulated. + Metrics.counter(PROXY_PROTOCOL_DETECTED_NAME, "detected", "false").increment(); + ctx.pipeline().remove(this); + break; + } + } + + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/SniMapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/SniMapper.java new file mode 100644 index 000000000..cb5d9299d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/SniMapper.java @@ -0,0 +1,164 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.handler.ssl.ApplicationProtocolConfig; +import io.netty.handler.ssl.ApplicationProtocolNames; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.util.Mapping; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.Key; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SniMapper { + + private static final Logger logger = LoggerFactory.getLogger(SniMapper.class); + + private SniMapper() { + } + + /// Build a [Mapping] from a [KeyStore] that maps from domain (via SAN) to [io.netty.handler.ssl.SslContext] that + /// can be used with an [io.netty.handler.ssl.SniHandler]. The provided keystore may contain multiple certificates + /// for a single domain, all matching certificates will be included in the corresponding SslContext. The domain for + /// a certificate is only determined by the SAN and all certificates must have a SAN. The returned [Mapping] returns + /// an arbitrary set of certificates if none of the certificates in the keystore match a requested domain, as + /// permitted by RFC-6066. + /// + /// @param keyStorePath The path to the [KeyStore] + /// @param keyStorePassword The password for the keyStore + /// @return A [Mapping] that maps domains to the corresponding [SslContext] containing the certificates for that + /// domain + public static Mapping buildSniMapping(final String keyStorePath, final String keyStorePassword) + throws IOException { + try (final FileInputStream fis = new FileInputStream(keyStorePath)) { + return buildSniMapping(fis, keyStorePassword); + } + } + + @VisibleForTesting + static Mapping buildSniMapping(final InputStream keyStore, final String keyStorePassword) + throws IOException { + try { + final Map domainKeyStores = partitionByDomain(keyStore, keyStorePassword.toCharArray()); + final Map sslContextsByDomain = new HashMap<>(); + for (final Map.Entry entry : domainKeyStores.entrySet()) { + final KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(entry.getValue(), keyStorePassword.toCharArray()); + sslContextsByDomain.put(entry.getKey(), buildSslContext(kmf)); + } + + // Netty expects the SNI mapping to always return an SslContext. Per RFC-6066 it's valid to continue the handshake + // on an SNI mismatch, it's the client's responsibility to check the returned certificate's SNI. We sort first so + // our choice of certificate is deterministic. + final SslContext defaultSslContext = sslContextsByDomain.entrySet().stream() + .min(Map.Entry.comparingByKey()) + .orElseThrow(() -> new IllegalArgumentException("Key store contained no certificates")) + .getValue(); + + logger.info("Loaded TLS contexts for domains: {}", sslContextsByDomain.keySet()); + return hostname -> sslContextsByDomain.getOrDefault(hostname, defaultSslContext); + } catch (NoSuchAlgorithmException | KeyStoreException | CertificateException | UnrecoverableKeyException e) { + throw new IOException("Failed to load keystore", e); + } + } + + private static SslContext buildSslContext(final KeyManagerFactory kmf) throws SSLException { + return SslContextBuilder.forServer(kmf) + .applicationProtocolConfig(new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_2)) + .protocols("TLSv1.3") + .build(); + } + + private static Map partitionByDomain(final InputStream keystoreStream, final char[] keystorePassword) + throws KeyStoreException, CertificateException, UnrecoverableKeyException, IOException, NoSuchAlgorithmException { + + final KeyStore keyStore = KeyStore.getInstance("PKCS12"); + keyStore.load(keystoreStream, keystorePassword); + + // Group key entries by the domain(s) in each certificate's SANs + final Map domainKeyStores = new HashMap<>(); + for (final String alias : Collections.list(keyStore.aliases())) { + if (!keyStore.isKeyEntry(alias)) { + continue; + } + + final Certificate[] chain = keyStore.getCertificateChain(alias); + if (chain == null || chain.length == 0) { + continue; + } + + final X509Certificate leaf = (X509Certificate) chain[0]; + final Key key = keyStore.getKey(alias, keystorePassword); + + for (final String domain : getDnsNames(leaf)) { + domainKeyStores + .computeIfAbsent(domain, _ -> newEmptyKeyStore()) + .setKeyEntry(alias, key, keystorePassword, chain); + } + } + + if (domainKeyStores.isEmpty()) { + throw new IOException("Keystore contains no usable key entries with DNS names"); + } + return domainKeyStores; + + } + + private static KeyStore newEmptyKeyStore() { + try { + final KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + return ks; + } catch (KeyStoreException | IOException | NoSuchAlgorithmException | CertificateException e) { + // All Java runtime implementations are required to support PKCS12, and we aren't loading anything from disk + // so an exception here is impossible. + throw new AssertionError("Failed to create empty keystore", e); + } + } + + /// Extract all DNS-type SAN names on this certificate + private static List getDnsNames(final X509Certificate cert) throws CertificateParsingException, IOException { + final Collection> sans = cert.getSubjectAlternativeNames(); + if (sans == null) { + throw new IOException("Certificate did not have SAN extension"); + } + final List dnsSans = sans.stream() + // GeneralName type 2 = dNSName. See getSubjectAlternativeNames + .filter(san -> (int) san.getFirst() == 2) + .map(san -> (String) san.get(1)) + .map(s -> s.toLowerCase(Locale.ROOT)) + .toList(); + if (dnsSans.isEmpty()) { + throw new IOException("Certificate did not have a DNS SAN entry"); + } + return dnsSans; + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/jetty/JettyHttpConfigurationCustomizer.java b/service/src/main/java/org/whispersystems/textsecuregcm/jetty/JettyHttpConfigurationCustomizer.java index 3415c0086..2f44aefc6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/jetty/JettyHttpConfigurationCustomizer.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/jetty/JettyHttpConfigurationCustomizer.java @@ -40,8 +40,6 @@ public class JettyHttpConfigurationCustomizer implements Container.Listener, Lif httpConfiguration.setNotifyRemoteAsyncErrors(false); } } - - c.addBean(new JettyConnectionMetrics(Metrics.globalRegistry)); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsApplicationEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsApplicationEventListener.java index 5c58de79a..2d93ec012 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsApplicationEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsApplicationEventListener.java @@ -14,7 +14,7 @@ import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; /** * Delegates request events to a listener that captures and reports request-level metrics. * - * @see MetricsHttpChannelListener + * @see MetricsHttpEventHandler * @see MetricsRequestEventListener */ public class MetricsApplicationEventListener implements ApplicationEventListener { @@ -23,7 +23,7 @@ public class MetricsApplicationEventListener implements ApplicationEventListener public MetricsApplicationEventListener(final TrafficSource trafficSource, final ClientReleaseManager clientReleaseManager) { if (trafficSource == TrafficSource.HTTP) { - throw new IllegalArgumentException("Use " + MetricsHttpChannelListener.class.getName() + " for HTTP traffic"); + throw new IllegalArgumentException("Use " + MetricsHttpEventHandler.class.getName() + " for HTTP traffic"); } this.metricsRequestEventListener = new MetricsRequestEventListener(trafficSource, clientReleaseManager); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java deleted file mode 100644 index 25596c46a..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.metrics; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.net.HttpHeaders; -import io.dropwizard.core.setup.Environment; -import io.micrometer.core.instrument.MeterRegistry; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Tag; -import io.micrometer.core.instrument.Tags; -import jakarta.ws.rs.container.ContainerRequestContext; -import jakarta.ws.rs.container.ContainerResponseContext; -import jakarta.ws.rs.container.ContainerResponseFilter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; -import java.util.Optional; -import java.util.Set; -import javax.annotation.Nullable; -import org.apache.commons.lang3.StringUtils; -import org.eclipse.jetty.server.Connector; -import org.eclipse.jetty.server.HttpChannel; -import org.eclipse.jetty.server.Request; -import org.eclipse.jetty.util.component.Container; -import org.eclipse.jetty.util.component.LifeCycle; -import org.glassfish.jersey.server.ExtendedUriInfo; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; -import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil; -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgent; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; - -/** - * Gathers and reports HTTP request metrics at the Jetty container level, which sits above Jersey. In order to get - * templated Jersey request paths, it implements {@link jakarta.ws.rs.container.ContainerResponseFilter}, in order to give - * itself access to the template. It is limited to {@link TrafficSource#HTTP} requests. - *

- * It implements {@link LifeCycle.Listener} without overriding methods, so that it can be an event listener that - * Dropwizard will attach to the container—the {@link Container.Listener} implementation is where it attaches - * itself to any {@link Connector}s. - * - * @see MetricsRequestEventListener - */ -public class MetricsHttpChannelListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener, - ContainerResponseFilter { - - private static final Set EXPECTED_HTTP_METHODS = - Set.of("GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"); - - private static final Logger logger = LoggerFactory.getLogger(MetricsHttpChannelListener.class); - - private record RequestInfo(String path, String method, int statusCode, @Nullable String userAgent) { - } - - private final ClientReleaseManager clientReleaseManager; - private final Set servletPaths; - - // Use the same counter namespace as MetricsRequestEventListener for continuity - public static final String REQUEST_COUNTER_NAME = MetricsRequestEventListener.REQUEST_COUNTER_NAME; - public static final String REQUESTS_BY_VERSION_COUNTER_NAME = MetricsRequestEventListener.REQUESTS_BY_VERSION_COUNTER_NAME; - @VisibleForTesting - static final String RESPONSE_BYTES_COUNTER_NAME = MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME; - @VisibleForTesting - static final String REQUEST_BYTES_COUNTER_NAME = MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME; - @VisibleForTesting - static final String URI_INFO_PROPERTY_NAME = MetricsHttpChannelListener.class.getName() + ".uriInfo"; - - @VisibleForTesting - static final String PATH_TAG = "path"; - - @VisibleForTesting - static final String METHOD_TAG = "method"; - - @VisibleForTesting - static final String STATUS_CODE_TAG = "status"; - - @VisibleForTesting - static final String TRAFFIC_SOURCE_TAG = "trafficSource"; - - private final MeterRegistry meterRegistry; - - - public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager, final Set servletPaths) { - this(Metrics.globalRegistry, clientReleaseManager, servletPaths); - } - - @VisibleForTesting - MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager, - final Set servletPaths) { - this.meterRegistry = meterRegistry; - this.clientReleaseManager = clientReleaseManager; - this.servletPaths = servletPaths; - } - - public void configure(final Environment environment) { - // register as ContainerResponseFilter - environment.jersey().register(this); - - // hook into lifecycle events, to react to the Connector being added - environment.lifecycle().addEventListener(this); - } - - @Override - public void onRequestFailure(final Request request, final Throwable failure) { - - if (logger.isDebugEnabled()) { - final RequestInfo requestInfo = getRequestInfo(request); - - logger.debug("Request failure: {} {} ({}) [{}] ", - requestInfo.method(), - requestInfo.path(), - requestInfo.userAgent(), - requestInfo.statusCode(), failure); - } - } - - @Override - public void onResponseFailure(Request request, Throwable failure) { - - if (failure instanceof org.eclipse.jetty.io.EofException) { - // the client disconnected early - return; - } - - final RequestInfo requestInfo = getRequestInfo(request); - - logger.warn("Response failure: {} {} ({}) [{}] ", - requestInfo.method(), - requestInfo.path(), - requestInfo.userAgent(), - requestInfo.statusCode(), failure); - } - - @Override - public void onComplete(final Request request) { - - final RequestInfo requestInfo = getRequestInfo(request); - - @Nullable final UserAgent userAgent; - { - UserAgent parsedUserAgent; - - try { - parsedUserAgent = UserAgentUtil.parseUserAgentString(requestInfo.userAgent()); - } catch (final UnrecognizedUserAgentException e) { - parsedUserAgent = null; - } - - userAgent = parsedUserAgent; - } - - final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent); - - final List tags = new ArrayList<>(5); - tags.add(Tag.of(PATH_TAG, requestInfo.path())); - tags.add(Tag.of(METHOD_TAG, requestInfo.method())); - tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(requestInfo.statusCode()))); - tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())); - tags.add(platformTag); - - final Optional maybeClientVersionTag = - UserAgentTagUtil.getClientVersionTag(userAgent, clientReleaseManager); - - maybeClientVersionTag.ifPresent(tags::add); - - meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment(); - - meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags).increment(request.getResponse().getContentCount()); - meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags).increment(request.getContentRead()); - - maybeClientVersionTag.ifPresent(clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME, - Tags.of(clientVersionTag, platformTag)) - .increment()); - } - - @Override - public void beanAdded(final Container parent, final Object child) { - if (child instanceof Connector connector) { - connector.addBean(this); - } - } - - @Override - public void beanRemoved(final Container parent, final Object child) { - - } - - @Override - public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext) - throws IOException { - requestContext.setProperty(URI_INFO_PROPERTY_NAME, requestContext.getUriInfo()); - } - - private RequestInfo getRequestInfo(Request request) { - final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME)) - .map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr)) - .orElseGet(() -> - Optional.ofNullable(request.getPathInfo()) - .filter(servletPaths::contains) - .orElse("unknown") - ); - - // Response cannot be null, but its status might not always reflect an actual response status, since it gets - // initialized to 200 - final int status = request.getResponse().getStatus(); - - @Nullable final String userAgent = request.getHeader(HttpHeaders.USER_AGENT); - - return new RequestInfo(path, normalizeMethod(request.getMethod()), status, userAgent); - } - - static String normalizeMethod(@Nullable final String method) { - if (StringUtils.isBlank(method)) { - return "unknown"; - } - - return EXPECTED_HTTP_METHODS.contains(method.toUpperCase(Locale.ROOT)) ? method : "unknown"; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandler.java new file mode 100644 index 000000000..5be3e7366 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandler.java @@ -0,0 +1,280 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.metrics; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HttpHeaders; +import io.dropwizard.core.setup.Environment; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import jakarta.validation.constraints.NotNull; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerResponseContext; +import jakarta.ws.rs.container.ContainerResponseFilter; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.commons.lang3.StringUtils; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.io.Content; +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.handler.EventsHandler; +import org.eclipse.jetty.util.component.LifeCycle; +import org.glassfish.jersey.server.ExtendedUriInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil; +import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; + +/** + * Gathers and reports HTTP request metrics at the Jetty container level, which sits above Jersey. In order to get + * templated Jersey request paths, it adds a {@link jakarta.ws.rs.container.ContainerResponseFilter}, in order to give + * itself access to the template. It is limited to {@link TrafficSource#HTTP} requests. + * + * @see MetricsRequestEventListener + */ +public class MetricsHttpEventHandler extends EventsHandler { + + private static final Logger logger = LoggerFactory.getLogger(MetricsHttpEventHandler.class); + + private static final Set EXPECTED_HTTP_METHODS = + Set.of("GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"); + + private final ClientReleaseManager clientReleaseManager; + private final Set servletPaths; + + // Use the same counter namespace as MetricsRequestEventListener for continuity + public static final String REQUEST_COUNTER_NAME = MetricsRequestEventListener.REQUEST_COUNTER_NAME; + public static final String REQUESTS_BY_VERSION_COUNTER_NAME = MetricsRequestEventListener.REQUESTS_BY_VERSION_COUNTER_NAME; + @VisibleForTesting + static final String RESPONSE_BYTES_COUNTER_NAME = MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME; + @VisibleForTesting + static final String REQUEST_BYTES_COUNTER_NAME = MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME; + @VisibleForTesting + static final String REQUEST_INFO_PROPERTY_NAME = MetricsHttpEventHandler.class.getName() + ".requestInfo"; + + @VisibleForTesting + static final String PATH_TAG = "path"; + + @VisibleForTesting + static final String METHOD_TAG = "method"; + + @VisibleForTesting + static final String STATUS_CODE_TAG = "status"; + + @VisibleForTesting + static final String TRAFFIC_SOURCE_TAG = "trafficSource"; + + private final MeterRegistry meterRegistry; + + @VisibleForTesting + MetricsHttpEventHandler( + final Handler handler, + final MeterRegistry meterRegistry, + final ClientReleaseManager clientReleaseManager, + final Set servletPaths) { + super(handler); + + this.meterRegistry = meterRegistry; + this.clientReleaseManager = clientReleaseManager; + this.servletPaths = servletPaths; + } + + /** + * Configure a {@link MetricsHttpEventHandler} + * + * @param environment A dropwizard {@link org.eclipse.jetty.util.component.Environment} + * @param meterRegistry The meter registry to register metrics with + * @param clientReleaseManager A {@link ClientReleaseManager} that determines what tags to include with metrics + * @param servletPaths An allow-list of paths to include in metric tags for requests that are handled by above + * Jersey + */ + public static void configure(final Environment environment, final MeterRegistry meterRegistry, + final ClientReleaseManager clientReleaseManager, final Set servletPaths) { + // register a filter that will set the initial request info + environment.jersey().register(new SetInfoRequestFilter()); + + // hook into lifecycle events, to react to the Connector being added + environment.lifecycle().addEventListener(new LifeCycle.Listener() { + @Override + public void lifeCycleStarting(LifeCycle event) { + if (event instanceof Server server) { + server.setHandler( + new MetricsHttpEventHandler(server.getHandler(), meterRegistry, clientReleaseManager, servletPaths)); + } + } + }); + } + + private void onResponseFailure(Request request, int status, Throwable failure) { + + if (failure instanceof org.eclipse.jetty.io.EofException) { + // the client disconnected early + return; + } + + final RequestInfo requestInfo = getRequestInfo(request); + + logger.warn("Response failure: {} {} ({}) [{}] ", + requestInfo.method, + requestInfo.path, + requestInfo.userAgent, + status, + failure); + } + + @Override + public void onComplete(Request request, int status, HttpFields headers, Throwable failure) { + + super.onComplete(request, status, headers, failure); + + if (failure != null) { + onResponseFailure(request, status, failure); + } + + final RequestInfo requestInfo = getRequestInfo(request); + + @Nullable final UserAgent userAgent; + { + UserAgent parsedUserAgent; + + try { + parsedUserAgent = UserAgentUtil.parseUserAgentString(requestInfo.userAgent); + } catch (final UnrecognizedUserAgentException e) { + parsedUserAgent = null; + } + + userAgent = parsedUserAgent; + } + final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent); + + + final List tags = new ArrayList<>(5); + tags.add(Tag.of(PATH_TAG, requestInfo.path)); + tags.add(Tag.of(METHOD_TAG, requestInfo.method)); + tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(status))); + tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())); + tags.add(platformTag); + + final Optional maybeClientVersionTag = + UserAgentTagUtil.getClientVersionTag(userAgent, clientReleaseManager); + + maybeClientVersionTag.ifPresent(tags::add); + + meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment(); + meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags).increment(requestInfo.responseBytes); + meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags).increment(requestInfo.requestBytes); + + maybeClientVersionTag.ifPresent(clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME, + Tags.of(clientVersionTag, platformTag)) + .increment()); + } + + @Override + protected void onRequestRead(final Request request, final Content.Chunk chunk) { + super.onRequestRead(request, chunk); + if (chunk != null) { + getRequestInfo(request).requestBytes += chunk.remaining(); + } + } + + @Override + protected void onResponseWrite(final Request request, final boolean last, final ByteBuffer content) { + super.onResponseWrite(request, last, content); + if (content != null) { + getRequestInfo(request).responseBytes += content.remaining(); + } + } + + private RequestInfo getRequestInfo(Request request) { + Object obj = request.getAttribute(REQUEST_INFO_PROPERTY_NAME); + if (obj != null && obj instanceof RequestInfo requestInfo) { + return requestInfo; + } + + // Our ContainerResponseFilter has not run yet. It should eventually run, and will override the path we set here. + // It may not run if this is a websocket upgrade request, a request handled by jetty directly, or a higher priority + // filter aborted the request by throwing an exception, in which case we'll use this path. To avoid giving every + // incorrect path a unique tag we check against a configured list of paths that we know would skip the filter. + final RequestInfo newInfo = new RequestInfo( + Optional.ofNullable(request.getHttpURI().getPath()).filter(servletPaths::contains).orElse("unknown"), + normalizeMethod(request.getMethod()), + request.getHeaders().get(HttpHeaders.USER_AGENT)); + + request.setAttribute(REQUEST_INFO_PROPERTY_NAME, newInfo); + return newInfo; + } + + @VisibleForTesting + static class RequestInfo { + + private String path; + private final String method; + private final @Nullable String userAgent; + + private long requestBytes; + private long responseBytes; + + RequestInfo(@NotNull String path, @NotNull String method, @Nullable String userAgent) { + this.path = path; + this.method = method; + this.userAgent = userAgent; + this.requestBytes = 0; + this.responseBytes = 0; + } + + @Override + public boolean equals(final Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + RequestInfo that = (RequestInfo) o; + return requestBytes == that.requestBytes && responseBytes == that.responseBytes && Objects.equals(path, that.path) + && Objects.equals(method, that.method) && Objects.equals(userAgent, that.userAgent); + } + + @Override + public int hashCode() { + return Objects.hash(path, method, userAgent, requestBytes, responseBytes); + } + } + + @VisibleForTesting + static class SetInfoRequestFilter implements ContainerResponseFilter { + + @Override + public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext) { + // Construct the templated URI path. If no matching path is found, this will be "" + final String path = UriInfoUtil.getPathTemplate((ExtendedUriInfo) requestContext.getUriInfo()); + final Object obj = requestContext.getProperty(REQUEST_INFO_PROPERTY_NAME); + if (obj != null && obj instanceof RequestInfo requestInfo) { + requestInfo.path = path; + } else { + requestContext.setProperty(REQUEST_INFO_PROPERTY_NAME, + new RequestInfo(path, requestContext.getMethod(), requestContext.getHeaderString(HttpHeaders.USER_AGENT))); + } + } + } + + static String normalizeMethod(@Nullable final String method) { + if (StringUtils.isBlank(method)) { + return "unknown"; + } + + return EXPECTED_HTTP_METHODS.contains(method.toUpperCase(Locale.ROOT)) ? method : "unknown"; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java index a4a85c1d5..5f5897a5a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java @@ -31,7 +31,7 @@ import org.whispersystems.websocket.WebSocketResourceProvider; /** * Gathers and reports request-level metrics for WebSocket traffic only. - * For HTTP traffic, use {@link MetricsHttpChannelListener}. + * For HTTP traffic, use {@link MetricsHttpEventHandler}. */ public class MetricsRequestEventListener implements RequestEventListener { @@ -69,7 +69,7 @@ public class MetricsRequestEventListener implements RequestEventListener { this(trafficSource, Metrics.globalRegistry, clientReleaseManager); if (trafficSource == TrafficSource.HTTP) { - logger.warn("Use {} for HTTP traffic", MetricsHttpChannelListener.class.getName()); + logger.warn("Use {} for HTTP traffic", MetricsHttpEventHandler.class.getName()); } } @@ -107,7 +107,7 @@ public class MetricsRequestEventListener implements RequestEventListener { final List tags = new ArrayList<>(); tags.add(Tag.of(PATH_TAG, UriInfoUtil.getPathTemplate(event.getUriInfo()))); - tags.add(Tag.of(METHOD_TAG, MetricsHttpChannelListener.normalizeMethod(event.getContainerRequest().getMethod()))); + tags.add(Tag.of(METHOD_TAG, MetricsHttpEventHandler.normalizeMethod(event.getContainerRequest().getMethod()))); tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(Optional .ofNullable(event.getContainerResponse()) .map(ContainerResponse::getStatus) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtil.java index 06f786649..04b29941f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtil.java @@ -21,7 +21,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Optional; -import org.eclipse.jetty.util.resource.Resource; +import org.eclipse.jetty.util.resource.PathResourceFactory; import org.eclipse.jetty.util.security.CertificateUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,7 +37,7 @@ public class TlsCertificateExpirationUtil { final KeyStore keyStore; try { - keyStore = CertificateUtils.getKeyStore(Resource.newResource(keyStorePath), keyStoreType, keyStoreProvider, + keyStore = CertificateUtils.getKeyStore(new PathResourceFactory().newResource(keyStorePath), keyStoreType, keyStoreProvider, keyStorePassword); } catch (Exception e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClient.java index d57ab65f3..302bb8ce4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClient.java @@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.redis; import com.google.common.annotations.VisibleForTesting; import io.github.resilience4j.circuitbreaker.CircuitBreaker; import io.lettuce.core.ClientOptions; +import io.lettuce.core.MaintNotificationsConfig; import io.lettuce.core.RedisClient; import io.lettuce.core.RedisException; import io.lettuce.core.RedisURI; @@ -57,20 +58,10 @@ public class FaultTolerantRedisClient { this.name = name; - // Lettuce will issue a CLIENT SETINFO command unconditionally if these fields are set (and they are by default), - // which can generate a bunch of spurious warnings in versions of Redis before 7.2.0. - // - // See: - // - // - https://github.com/redis/lettuce/pull/2823 - // - https://github.com/redis/lettuce/issues/2817 - redisUri.setClientName(null); - redisUri.setLibraryName(null); - redisUri.setLibraryVersion(null); - this.redisClient = RedisClient.create(clientResourcesBuilder.build(), redisUri); final ClientOptions.Builder clientOptionsBuilder = ClientOptions.builder() .disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS) + .maintNotificationsConfig(MaintNotificationsConfig.disabled()) // for asynchronous commands .timeoutOptions(TimeoutOptions.builder() .fixedTimeout(commandTimeout) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java index e0fba6ef7..ae277b6cf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java @@ -9,6 +9,7 @@ import io.github.resilience4j.core.IntervalFunction; import io.github.resilience4j.retry.Retry; import io.github.resilience4j.retry.RetryConfig; import io.lettuce.core.ClientOptions; +import io.lettuce.core.MaintNotificationsConfig; import io.lettuce.core.RedisException; import io.lettuce.core.RedisURI; import io.lettuce.core.TimeoutOptions; @@ -69,28 +70,15 @@ public class FaultTolerantRedisClusterClient { this.name = name; - // Lettuce will issue a CLIENT SETINFO command unconditionally if these fields are set (and they are by default), - // which can generate a bunch of spurious warnings in versions of Redis before 7.2.0. - // - // See: - // - // - https://github.com/redis/lettuce/pull/2823 - // - https://github.com/redis/lettuce/issues/2817 - redisUris.forEach(redisUri -> { - redisUri.setClientName(null); - redisUri.setLibraryName(null); - redisUri.setLibraryVersion(null); - }); - final LettuceShardCircuitBreaker lettuceShardCircuitBreaker = new LettuceShardCircuitBreaker(name, circuitBreakerConfigurationName); this.clusterClient = RedisClusterClient.create( - clientResourcesBuilder.nettyCustomizer(lettuceShardCircuitBreaker). - build(), + clientResourcesBuilder.nettyCustomizer(lettuceShardCircuitBreaker) + .build(), redisUris); - final ClusterClientOptions.Builder clusterClientOptionsBuilder = ClusterClientOptions.builder() + final ClusterClientOptions.Builder clusterClientOptionsBuilder = (ClusterClientOptions.Builder) ClusterClientOptions.builder() .disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS) .validateClusterNodeMembership(false) .topologyRefreshOptions(ClusterTopologyRefreshOptions.builder() @@ -100,7 +88,8 @@ public class FaultTolerantRedisClusterClient { .timeoutOptions(TimeoutOptions.builder() .fixedTimeout(commandTimeout) .build()) - .publishOnScheduler(true); + .publishOnScheduler(true) + .maintNotificationsConfig(MaintNotificationsConfig.disabled()); NettyUtil.setSocketTimeoutsIfApplicable(clusterClientOptionsBuilder); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/JmxDumper.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/JmxDumper.java deleted file mode 100644 index 036917805..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/JmxDumper.java +++ /dev/null @@ -1,33 +0,0 @@ -package org.whispersystems.textsecuregcm.util; - -import io.dropwizard.lifecycle.Managed; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import javax.management.MBeanServer; -import java.lang.management.ManagementFactory; -import java.util.concurrent.ScheduledExecutorService; - -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -public class JmxDumper implements Managed { - private static final Logger log = LoggerFactory.getLogger(JmxDumper.class); - - - private final ScheduledExecutorService executor; - - public JmxDumper(final ScheduledExecutorService executor) { - this.executor = executor; - } - - @Override - public void start() throws Exception { -// executor.schedule() - } - - private void dump() { - MBeanServer mbs = ManagementFactory.getPlatformMBeanServer(); - - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java index c88fa5e29..4ba267dfc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java @@ -8,14 +8,14 @@ package org.whispersystems.textsecuregcm.websocket; import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentialsFromAuthHeader; import com.google.common.net.HttpHeaders; -import javax.annotation.Nullable; import io.dropwizard.auth.basic.BasicCredentials; -import org.eclipse.jetty.websocket.api.UpgradeRequest; +import java.util.Optional; +import javax.annotation.Nullable; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; -import java.util.Optional; public class WebSocketAccountAuthenticator implements WebSocketAuthenticator { @@ -27,7 +27,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator authenticate(final UpgradeRequest request) + public Optional authenticate(final JettyServerUpgradeRequest request) throws InvalidCredentialsException { @Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java index 4183026b3..cce4dedbd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java @@ -15,7 +15,6 @@ import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import org.apache.commons.lang3.RandomStringUtils; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.glassfish.jersey.server.ManagedAsync; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -35,7 +34,6 @@ public class BufferingInterceptorIntegrationTest { environment.jersey().register(testController); environment.jersey().register(new BufferingInterceptor()); environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10)); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java index ef37d6b64..5e647359a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java @@ -18,28 +18,29 @@ import io.dropwizard.core.Configuration; import io.dropwizard.core.setup.Environment; import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import jakarta.servlet.DispatcherType; -import jakarta.servlet.ServletRegistration; import java.io.IOException; import java.net.URI; +import java.nio.ByteBuffer; import java.time.Duration; -import java.util.EnumSet; import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; @@ -52,6 +53,7 @@ import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.setup.WebSocketEnvironment; @ExtendWith(DropwizardExtensionsSupport.class) +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public class ProvisioningTimeoutIntegrationTest { private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = @@ -77,9 +79,9 @@ public class ProvisioningTimeoutIntegrationTest { CompletableFuture provisioningAddressFuture = new CompletableFuture<>(); @Override - public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE && webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) { MessageProtos.ProvisioningAddress provisioningAddress = @@ -92,7 +94,7 @@ public class ProvisioningTimeoutIntegrationTest { } catch (InvalidProtocolBufferException e) { throw new RuntimeException(e); } - super.onWebSocketBinary(payload, offset, length); + super.onWebSocketBinary(payload, callback); } } @@ -106,21 +108,17 @@ public class ProvisioningTimeoutIntegrationTest { final WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment<>(environment, webSocketConfiguration); - environment.servlets() - .addFilter("RemoteAddressFilter", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); webSocketEnvironment.setConnectListener( new ProvisioningConnectListener(mock(ProvisioningManager.class), () -> mock(AsnInfoProvider.class), mock(ClientReleaseManager.class), scheduler, Duration.ofSeconds(5))); final WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, - webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); + REMOTE_ADDRESS_ATTRIBUTE_NAME); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); - final ServletRegistration.Dynamic websocketServlet = environment.servlets() - .addServlet("WebSocket", webSocketServlet); - websocketServlet.addMapping("/websocket"); - websocketServlet.setAsyncSupported(true); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { + container.addMapping("/websocket", webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java index d74e599b0..606ca0739 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java @@ -9,8 +9,6 @@ import io.dropwizard.core.Configuration; import io.dropwizard.core.setup.Environment; import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import jakarta.servlet.DispatcherType; -import jakarta.servlet.ServletRegistration; import jakarta.ws.rs.GET; import jakarta.ws.rs.PUT; import jakarta.ws.rs.Path; @@ -20,19 +18,20 @@ import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import java.io.IOException; import java.net.URI; -import java.util.EnumSet; import java.util.Optional; import org.apache.commons.lang3.RandomStringUtils; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ServerProperties; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.websocket.WebSocketResourceProviderFactory; @@ -41,6 +40,7 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.setup.WebSocketEnvironment; @ExtendWith(DropwizardExtensionsSupport.class) +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public class WebsocketResourceProviderIntegrationTest { private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(TestApplication.class); @@ -72,9 +72,6 @@ public class WebsocketResourceProviderIntegrationTest { new WebSocketEnvironment<>(environment, webSocketConfiguration); environment.jersey().register(testController); - environment.servlets() - .addFilter("RemoteAddressFilter", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); webSocketEnvironment.jersey().register(testController); webSocketEnvironment.jersey().register(new RemoteAddressFilter()); webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class))); @@ -85,15 +82,13 @@ public class WebsocketResourceProviderIntegrationTest { final WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, - webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); + REMOTE_ADDRESS_ATTRIBUTE_NAME); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { + container.addMapping("/websocket", webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); - final ServletRegistration.Dynamic websocketServlet = - environment.servlets().addServlet("WebSocket", webSocketServlet); - - websocketServlet.addMapping("/websocket"); - websocketServlet.setAsyncSupported(true); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/WhisperServerServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/WhisperServerServiceTest.java index 910017759..2a36b2d0d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/WhisperServerServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/WhisperServerServiceTest.java @@ -13,35 +13,47 @@ import io.dropwizard.testing.ConfigOverride; import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.util.Resources; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.stub.MetadataUtils; import jakarta.ws.rs.client.Client; import jakarta.ws.rs.core.Response; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.ServerSocket; import java.net.URI; import java.util.List; import java.util.Map; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Semaphore; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.http2.client.HTTP2Client; +import org.eclipse.jetty.http2.client.transport.HttpClientTransportOverHTTP2; +import org.eclipse.jetty.io.ClientConnector; import org.eclipse.jetty.util.component.LifeCycle; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.StatusCode; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.RegisterExtension; -import org.whispersystems.textsecuregcm.configuration.OpenTelemetryConfiguration; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.signal.chat.account.AccountsAnonymousGrpc; +import org.signal.chat.account.CheckAccountExistenceRequest; +import org.signal.chat.account.CheckAccountExistenceResponse; +import org.signal.chat.common.IdentityType; +import org.signal.chat.common.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.NoopAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; @@ -58,13 +70,16 @@ class WhisperServerServiceTest { System.setProperty("secrets.bundle.filename", Resources.getResource("config/test-secrets-bundle.yml").getPath()); } + private static final int OMNIBUS_PORT = findAvailablePort(); - private static final WebSocketClient webSocketClient = new WebSocketClient(); + private static WebSocketClient webSocketClient; + private static WebSocketClient h2WebSocketClient; private static final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>( WhisperServerService.class, Resources.getResource("config/test.yml").getPath(), // Tables will be created by the local DynamoDbExtension - ConfigOverride.config("dynamoDbClient.initTables", "false")); + ConfigOverride.config("dynamoDbClient.initTables", "false"), + ConfigOverride.config("grpc.port", String.valueOf(OMNIBUS_PORT))); @RegisterExtension public static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.values()); @@ -76,7 +91,13 @@ class WhisperServerServiceTest { @BeforeAll static void setUp() throws Exception { + final ClientConnector clientConnector = new ClientConnector(); + final HTTP2Client http2Client = new HTTP2Client(clientConnector); + final HttpClient httpClient = new HttpClient(new HttpClientTransportOverHTTP2(http2Client)); + h2WebSocketClient = new WebSocketClient(httpClient); + webSocketClient = new WebSocketClient(); webSocketClient.start(); + h2WebSocketClient.start(); } @Test @@ -100,8 +121,9 @@ class WhisperServerServiceTest { assertEquals(200, healthCheck.getStatus()); } - @Test - void websocket() throws Exception { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void websocket(final boolean useH2) throws Exception { // test unauthenticated websocket final long start = System.currentTimeMillis(); @@ -117,7 +139,7 @@ class WhisperServerServiceTest { }); // Session is Closeable, but we intentionally keep it open so that we can confirm the container Lifecycle behavior - final Session session = webSocketClient.connect(testWebsocketListener, + final Session session = (useH2 ? h2WebSocketClient : webSocketClient).connect(testWebsocketListener, URI.create(String.format("ws://localhost:%d/v1/websocket/", EXTENSION.getLocalPort()))) .join(); final long sessionTimestamp = Long.parseLong(session.getUpgradeResponse().getHeader(HeaderUtils.TIMESTAMP_HEADER)); @@ -185,6 +207,51 @@ class WhisperServerServiceTest { .build()); } + + @Test + void omnibusWebsocket() throws Exception { + final HTTP2Client http2Client = new HTTP2Client(new ClientConnector()); + final WebSocketClient h2WebSocketClient = + new WebSocketClient(new HttpClient(new HttpClientTransportOverHTTP2(http2Client))); + h2WebSocketClient.start(); + + final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); + final Session session = h2WebSocketClient.connect(testWebsocketListener, + URI.create(String.format("ws://localhost:%d/v1/websocket/", OMNIBUS_PORT))) + .join(); + final WebSocketResponseMessage keepAlive = testWebsocketListener.doGet("/v1/keepalive").join(); + assertEquals(200, keepAlive.getStatus()); + final WebSocketResponseMessage whoami = testWebsocketListener.doGet("/v1/accounts/whoami").join(); + assertEquals(401, whoami.getStatus()); + session.close(); + h2WebSocketClient.stop(); + } + + @Test + void omnibusGrpc() throws Exception { + final ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", OMNIBUS_PORT) + .usePlaintext() + .build(); + + final Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("X-Forwarded-For", Metadata.ASCII_STRING_MARSHALLER), "127.0.0.1"); + + final AccountsAnonymousGrpc.AccountsAnonymousBlockingStub stub = AccountsAnonymousGrpc + .newBlockingStub(channel) + .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + final CheckAccountExistenceResponse response = stub.checkAccountExistence( + CheckAccountExistenceRequest.newBuilder() + .setServiceIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) + .build()) + .build()); + + assertFalse(response.getAccountExists()); + channel.shutdownNow(); + channel.awaitTermination(1, TimeUnit.SECONDS); + } + private static DynamoDbClient getDynamoDbClient() { final AwsCredentialsProvider awsCredentialsProvider = EXTENSION.getConfiguration().getAwsCredentialsConfiguration() .build(); @@ -193,4 +260,12 @@ class WhisperServerServiceTest { .buildSyncClient(awsCredentialsProvider, new NoopAwsSdkMetricPublisher()); } + private static int findAvailablePort() { + try (ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java index 28ea6d113..e7417ebc1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java @@ -15,8 +15,8 @@ import java.util.List; import java.util.Optional; import java.util.UUID; import javax.annotation.Nullable; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneVerificationTokenManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneVerificationTokenManagerTest.java index 4c34b8313..d21cc4cda 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneVerificationTokenManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneVerificationTokenManagerTest.java @@ -210,8 +210,7 @@ class PhoneVerificationTokenManagerTest { @ParameterizedTest @MethodSource - void verifyRecoveryPasswordManagerException(final Throwable recoveryPasswordManagerException) - throws ExecutionException, InterruptedException, TimeoutException { + void verifyRecoveryPasswordManagerException(final Throwable recoveryPasswordManagerException) { final ContainerRequestContext containerRequestContext = mock(ContainerRequestContext.class); final byte[] recoveryPassword = TestRandomUtil.nextBytes(16); @@ -219,11 +218,8 @@ class PhoneVerificationTokenManagerTest { when(registrationRecoveryChecker.checkRegistrationRecoveryAttempt(containerRequestContext, PHONE_NUMBER)) .thenReturn(true); - @SuppressWarnings("unchecked") final CompletableFuture mockFuture = mock(CompletableFuture.class); - when(mockFuture.get(anyLong(), any())).thenThrow(recoveryPasswordManagerException); - when(registrationRecoveryPasswordsManager.verify(PHONE_NUMBER_IDENTIFIER, recoveryPassword)) - .thenReturn(mockFuture); + .thenReturn(CompletableFuture.failedFuture(recoveryPasswordManagerException)); assertThrows(ServerErrorException.class, () -> phoneVerificationTokenManager.verify(containerRequestContext, PHONE_NUMBER, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java index 96ae5283f..edc6c576e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java @@ -13,20 +13,17 @@ import io.dropwizard.core.Configuration; import io.dropwizard.core.setup.Environment; import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import jakarta.servlet.DispatcherType; import jakarta.ws.rs.GET; import jakarta.ws.rs.Path; import jakarta.ws.rs.client.Client; import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.core.Context; -import java.io.IOException; import java.net.InetAddress; import java.net.URI; import java.nio.ByteBuffer; import java.security.Principal; import java.time.Duration; import java.util.Arrays; -import java.util.EnumSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -35,14 +32,15 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import javax.security.auth.Subject; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.util.HostPort; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -55,6 +53,7 @@ import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFa import org.whispersystems.websocket.setup.WebSocketEnvironment; @ExtendWith(DropwizardExtensionsSupport.class) +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class RemoteAddressFilterIntegrationTest { private static final String WEBSOCKET_PREFIX = "/websocket"; @@ -131,7 +130,7 @@ class RemoteAddressFilterIntegrationTest { } } - private static class ClientEndpoint implements WebSocketListener { + public static class ClientEndpoint implements Session.Listener.AutoDemanding { private final String requestPath; private final CompletableFuture responseFuture; @@ -145,22 +144,19 @@ class RemoteAddressFilterIntegrationTest { } @Override - public void onWebSocketConnect(final Session session) { + public void onWebSocketOpen(final Session session) { final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath, List.of("Accept: application/json"), Optional.empty()).toByteArray(); - try { - session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes)); - } catch (IOException e) { - throw new RuntimeException(e); - } + + session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP); } @Override - public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { assert 200 == webSocketMessage.getResponseMessage().getStatus(); @@ -206,10 +202,6 @@ class RemoteAddressFilterIntegrationTest { public void run(final Configuration configuration, final Environment environment) throws Exception { - environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH, - WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH); - environment.jersey().register(new TestRemoteAddressController()); // WebSocket set up @@ -220,15 +212,14 @@ class RemoteAddressFilterIntegrationTest { webSocketEnvironment.jersey().register(new TestWebSocketController()); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); - WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, + webSocketEnvironment, TestPrincipal.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); - environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet) - .addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH); - + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { + container.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH, webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/StripContentLengthOnConnectFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/StripContentLengthOnConnectFilterTest.java new file mode 100644 index 000000000..a124653db --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/StripContentLengthOnConnectFilterTest.java @@ -0,0 +1,136 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.filters; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +import io.dropwizard.core.Application; +import io.dropwizard.core.Configuration; +import io.dropwizard.core.setup.Environment; +import io.dropwizard.testing.ConfigOverride; +import io.dropwizard.testing.junit5.DropwizardAppExtension; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import jakarta.servlet.DispatcherType; +import jakarta.servlet.Filter; +import jakarta.servlet.ServletContext; +import java.net.InetSocketAddress; +import java.util.EnumSet; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import org.eclipse.jetty.ee10.servlet.FilterHolder; +import org.eclipse.jetty.ee10.servlet.FilterMapping; +import org.eclipse.jetty.ee10.servlet.ServletContextHandler; +import org.eclipse.jetty.ee10.servlet.ServletHandler; +import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketCreator; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.http.HostPortHttpField; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.http.HttpScheme; +import org.eclipse.jetty.http.MetaData; +import org.eclipse.jetty.http2.api.Session; +import org.eclipse.jetty.http2.api.Stream; +import org.eclipse.jetty.http2.client.HTTP2Client; +import org.eclipse.jetty.http2.frames.HeadersFrame; +import org.eclipse.jetty.server.handler.ContextHandler; +import org.eclipse.jetty.util.Jetty; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(DropwizardExtensionsSupport.class) +class StripContentLengthOnConnectFilterTest { + + private static final String WEBSOCKET_WITH_FILTER_PATH = "/websocket/filtered"; + private static final String WEBSOCKET_WITHOUT_FILTER_PATH = "/websocket/unfiltered"; + + private static final DropwizardAppExtension EXTENSION = + new DropwizardAppExtension<>(TestApplicationWithFilter.class, null, + ConfigOverride.config("server.applicationConnectors[0].type", "h2c"), + ConfigOverride.config("server.applicationConnectors[0].port", "0")); + + + @Test + void contentLengthIsStrippedOnUpgrade() throws Exception { + final HttpFields fields = sendUpgradeRequest(WEBSOCKET_WITH_FILTER_PATH); + assertNull(fields.getField("content-length")); + } + + @Test + void contentLengthIncorrectlyIncludedOnUpgrade() throws Exception { + final HttpFields fields = sendUpgradeRequest(WEBSOCKET_WITHOUT_FILTER_PATH); + assertNotNull(fields.getField("content-length"), """ + If this fails, our jetty version no longer includes errant content-lengths on H2 connect responses. + StripContentLengthOnConnectFilter can now be removed. + """); + } + + @Test + void versionCheck() { + assertEquals("12.1.5", Jetty.VERSION, "This class can be removed with https://github.com/jetty/jetty.project/issues/15074, likely 12.1.10"); + } + + private static HttpFields sendUpgradeRequest(final String path) throws Exception { + final int port = EXTENSION.getLocalPort(); + try (final HTTP2Client client = new HTTP2Client()) { + client.start(); + final Session session = + client.connect(new InetSocketAddress("localhost", port), new Session.Listener() {}).join(); + final HttpFields requestFields = HttpFields.build().put("sec-websocket-version", "13"); + final HostPortHttpField hostPort = new HostPortHttpField("localhost:" + port); + final MetaData.ConnectRequest connect = + new MetaData.ConnectRequest(HttpScheme.HTTP, hostPort, path, requestFields, "websocket"); + final HeadersFrame headersFrame = new HeadersFrame(connect, null, false); + + final CompletableFuture responseFuture = new CompletableFuture<>(); + Stream.Listener streamListener = new Stream.Listener() { + @Override + public void onHeaders(Stream stream, HeadersFrame frame) { + if (frame.getMetaData().isResponse()) { + responseFuture.complete((MetaData.Response) frame.getMetaData()); + } + } + }; + session.newStream(headersFrame, streamListener).get(5, TimeUnit.SECONDS); + final MetaData.Response response = responseFuture.get(5, TimeUnit.SECONDS); + return response.getHttpFields(); + } + } + + private static class NoopWebSocket implements org.eclipse.jetty.websocket.api.Session.Listener.AutoDemanding {} + + public static class TestApplicationWithFilter extends Application { + + @Override + public void run(final Configuration configuration, final Environment environment) throws Exception { + final JettyWebSocketCreator jwsc = (_, _) -> new NoopWebSocket(); + JettyWebSocketServletContainerInitializer.configure( + environment.getApplicationContext(), + (servletContext, container) -> { + + container.addMapping(WEBSOCKET_WITH_FILTER_PATH, jwsc); + container.addMapping(WEBSOCKET_WITHOUT_FILTER_PATH, jwsc); + ensureFilter(servletContext, WEBSOCKET_WITH_FILTER_PATH, StripContentLengthOnConnectFilter.class); + }); + } + } + + private static void ensureFilter( + final ServletContext servletContext, + final String pathSpec, + final Class filterClass) { + + final ContextHandler contextHandler = ServletContextHandler.getServletContextHandler(servletContext); + final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class); + final FilterHolder holder = new FilterHolder(filterClass); + final FilterMapping mapping = new FilterMapping(); + mapping.setFilterName(holder.getName()); + mapping.setPathSpec(pathSpec); + mapping.setDispatcherTypes(EnumSet.of(DispatcherType.REQUEST)); + servletHandler.prependFilter(holder); + servletHandler.prependFilterMapping(mapping); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java new file mode 100644 index 000000000..e219e4f0f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java @@ -0,0 +1,25 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX License Identifier: AGPL 3.0 only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.util.ResourceLeakDetector; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; + +public abstract class AbstractLeakDetectionTest { + + private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel; + + @BeforeAll + static void setLeakDetectionLevel() { + originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel(); + ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID); + } + + @AfterAll + static void restoreLeakDetectionLevel() { + ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/H2FrameProxyHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/H2FrameProxyHandlerTest.java new file mode 100644 index 000000000..93ba8f82f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/H2FrameProxyHandlerTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.embedded.EmbeddedChannel; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; + +class H2FrameProxyHandlerTest extends AbstractLeakDetectionTest { + @Test + void proxyWritabilityChanged() { + final EmbeddedChannel target = new EmbeddedChannel(); + final EmbeddedChannel source = new EmbeddedChannel(new H2FrameProxyHandler(target, "test")); + + // Set a tiny watermark to guarantee an unflushed write sets to unwritable, and then buffer some data + final byte[] bufferedData = "8 bytes!".getBytes(StandardCharsets.UTF_8); + target.config().setWriteBufferWaterMark(new WriteBufferWaterMark(4, 8)); + target.write(bufferedData); + + assertNull(target.readOutbound(), "nothing should be written without a flush"); + assertFalse(target.isWritable(), "target should be unwritable because we've buffered more than the high watermark"); + assertFalse(source.config().isAutoRead(), "source should not read because the target is unwritable"); + + target.flush(); + assertTrue(target.isWritable(), "after a flush, the target should be writable"); + assertTrue(source.config().isAutoRead(), "after the target becomes writable, autoRead should be enabled"); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2ServerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2ServerTest.java new file mode 100644 index 000000000..660ff428c --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2ServerTest.java @@ -0,0 +1,600 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.grpc.netty.shaded.io.netty.channel.ChannelOption; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleUserEventChannelHandler; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioChannelOption; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.haproxy.HAProxyCommand; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; +import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; +import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; +import io.netty.handler.codec.http2.DefaultHttp2DataFrame; +import io.netty.handler.codec.http2.DefaultHttp2GoAwayFrame; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.DefaultHttp2HeadersFrame; +import io.netty.handler.codec.http2.DefaultHttp2ResetFrame; +import io.netty.handler.codec.http2.Http2DataFrame; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2FrameCodecBuilder; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2HeadersFrame; +import io.netty.handler.codec.http2.Http2MultiplexHandler; +import io.netty.handler.codec.http2.Http2ResetFrame; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.handler.codec.http2.Http2StreamChannel; +import io.netty.handler.codec.http2.Http2StreamChannelBootstrap; +import io.netty.handler.ssl.ApplicationProtocolConfig; +import io.netty.handler.ssl.ApplicationProtocolNames; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.util.ReferenceCountUtil; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.IntStream; +import javax.annotation.Nullable; +import javax.net.ssl.SSLException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +class OmnibusH2ServerTest extends AbstractLeakDetectionTest { + private static final String KEYSTORE_PASSWORD = "password"; + + // Paths that start with PREFIX should go to the prefix backend, everything else to default. + private static final String PREFIX_BACKEND_IDENTITY = "prefix-backend"; + private static final String PREFIX = "/v1/prefix"; + private static final String DEFAULT_BACKEND_IDENTITY = "default-backend"; + + private final NioEventLoopGroup nioEventLoopGroup = new NioEventLoopGroup(); + private final DefaultEventLoopGroup localEventLoopGroup = new DefaultEventLoopGroup(); + + private Channel defaultBackend; + private Channel prefixBackend; + private CompletableFuture backendConnection; + + private OmnibusH2Server server; + + @BeforeEach + void setUp() throws Exception { + // Start two H2C backend servers that echo a response with their identity + defaultBackend = startH2CServer(true, DEFAULT_BACKEND_IDENTITY); + prefixBackend = startH2CServer(false, PREFIX_BACKEND_IDENTITY); + backendConnection = new CompletableFuture<>(); + + // self-signed TLS context for the frontend loaded from test keyStore + final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12"); + + server = new OmnibusH2Server( + SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD), + nioEventLoopGroup, + localEventLoopGroup, + new InetSocketAddress("127.0.0.1", 0), + new OmnibusRouter( + List.of(new OmnibusRouter.OmnibusRoute(PREFIX, prefixBackend.localAddress())), + defaultBackend.localAddress()), + Duration.ofMinutes(1)); + + server.start(); + } + + @AfterEach + void tearDown() throws Exception { + server.stop(); + defaultBackend.close().sync(); + prefixBackend.close().sync(); + localEventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync(); + nioEventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync(); + } + + + @Test + void defaultBackend() { + final String response = sendRequestThroughOmnibus("/a/different/path"); + assertEquals(DEFAULT_BACKEND_IDENTITY, response); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void forwardedForHeader(final boolean usePpv2) { + final String expectedSource = usePpv2 ? "127.0.0.123" : "127.0.0.1"; + final HAProxyMessage proxyMessage = usePpv2 + ? new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, expectedSource, "127.0.0.2", 1234, 5678) + : null; + final Channel h2Connection = connectToOmnibus(null, proxyMessage); + final String xForwardedFor = sendRequestThroughOmnibus(h2Connection, "/forwarded-for"); + assertEquals(expectedSource, xForwardedFor); + } + + @ParameterizedTest + @ValueSource(strings = {"/v1/prefix", "/v1/prefix/", "/v1/prefix/other"}) + void prefixBackend(final String path) { + final String response = sendRequestThroughOmnibus(path); + assertEquals(PREFIX_BACKEND_IDENTITY, response); + } + + @Test + void multipleStreamsOnSameConnection() { + final Channel h2Connection = connectToOmnibus(); + final int numStreams = 10; + + // Create concurrent streams to both backends on the same connection simultaneously + @SuppressWarnings("rawtypes") + final CompletableFuture[] futures = IntStream.range(0, numStreams) + .mapToObj(i -> CompletableFuture.supplyAsync(() -> + sendRequestThroughOmnibus( + h2Connection, + i % 2 == 0 ? PREFIX : "/v1/other"))) + .toArray(CompletableFuture[]::new); + + // Ensure we get the response from the correct backend for each stream + CompletableFuture.allOf(futures).join(); + for (int i = 0; i < numStreams; i++) { + assertEquals( + i % 2 == 0 ? PREFIX_BACKEND_IDENTITY : DEFAULT_BACKEND_IDENTITY, + futures[i].resultNow()); + } + } + + @Test + void backendDownStreamReset() { + // Kill the default backend so connection attempts from the omnibus fail + defaultBackend.close().syncUninterruptibly(); + + final Channel h2Connection = connectToOmnibus(); + final CompletableFuture headersFuture = new CompletableFuture<>(); + final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection) + .handler(new HeadersCollectorHandler(headersFuture)) + .open() + .syncUninterruptibly() + .getNow(); + + final Http2Headers headers = new DefaultHttp2Headers() + .method("POST") + .path("/test") + .scheme("https") + .authority("localhost"); + stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, true)); + final Http2HeadersFrame responseHeaders = headersFuture.join(); + assertEquals("502", responseHeaders.headers().status().toString()); + + // Stream is dead, but connection should stay alive + assertFalse( + h2Connection.closeFuture().awaitUninterruptibly(5, TimeUnit.MILLISECONDS), + "connection should stay open"); + assertTrue(h2Connection.isOpen()); + + h2Connection.close().syncUninterruptibly(); + } + + @ParameterizedTest + @ValueSource(strings = {"/goaway", "/reset"}) + void backendCloseClosesClientStream(final String path) { + final Channel h2Connection = connectToOmnibus(); + final CompletableFuture resetFuture = new CompletableFuture<>(); + final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection) + .handler(new RstCollectorHandler(resetFuture)) + .open() + .syncUninterruptibly() + .getNow(); + + final Http2Headers headers = new DefaultHttp2Headers() + .method("POST") + // Triggers the server stream handler to either GOAWAY+close or send an RST based on the path + .path(path) + .scheme("https") + .authority("localhost"); + stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, true)); + assertEquals(Http2Error.CANCEL.code(), resetFuture.join().errorCode()); + + // client<->omnibus h2 connection stays open after a backend close/rst + assertTrue(h2Connection.isActive()); + } + + @Test + void queuedDataFrames() throws Exception { + final Channel h2Connection = connectToOmnibus(); + final CompletableFuture responseFuture = new CompletableFuture<>(); + final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection) + .handler(new ResponseCollectorHandler(responseFuture)) + .open() + .syncUninterruptibly() + .getNow(); + + final Http2Headers headers = new DefaultHttp2Headers() + .method("POST") + .path("/test") + .scheme("https") + .authority("localhost"); + + final int numFrames = 64; + + // Omnibus should handle queueing up frames while connecting to the backend if we blast all the frames right away + stream.write(new DefaultHttp2HeadersFrame(headers, false)); + final StringBuilder expectedBuilder = new StringBuilder(); + for (int i = 0; i < numFrames; i++) { + final String chunk = String.format("chunk-%03d;", i); + expectedBuilder.append(chunk); + final boolean endStream = (i == numFrames - 1); + stream.write(new DefaultHttp2DataFrame( + Unpooled.copiedBuffer(chunk, StandardCharsets.UTF_8), endStream)); + } + stream.flush(); + final String expected = expectedBuilder.toString(); + + final String response = responseFuture.get(10, TimeUnit.SECONDS); + assertEquals(expected, response); + + h2Connection.close().syncUninterruptibly(); + } + + @Test + void clientDisconnectClosesBackendConnections() throws Exception { + final Channel h2Connection = connectToOmnibus(); + + final Http2StreamChannelBootstrap streamBootstrap = new Http2StreamChannelBootstrap(h2Connection); + final Http2StreamChannel stream = streamBootstrap + .handler(new ResponseCollectorHandler(new CompletableFuture<>())) + .open() + .syncUninterruptibly() + .getNow(); + + // Write an endStream=false header so the stream stays open + final Http2Headers headers = new DefaultHttp2Headers() + .method("POST") + .path("/test") + .scheme("https") + .authority("localhost"); + stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, false)).syncUninterruptibly(); + final Channel backendServerChannel = backendConnection.join(); + assertFalse( + backendServerChannel.closeFuture().awaitUninterruptibly(10, TimeUnit.MILLISECONDS), + "Channel should be open"); + + // All backend connections the omnibus opened on behalf of this client should close if we disconnect the client + h2Connection.close().syncUninterruptibly(); + assertTrue(backendServerChannel.closeFuture().await(5, TimeUnit.SECONDS)); + } + + + @Test + void backpressure() throws ExecutionException, InterruptedException, TimeoutException { + final Channel h2Connection = connectToOmnibus(); + + // We'll take the client channel becoming unwritable as backpressure signal + final AtomicBoolean isWritable = new AtomicBoolean(true); + final CompletableFuture response = new CompletableFuture<>(); + final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection) + .handler(new ChannelInboundHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + isWritable.set(ctx.channel().isWritable()); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + try { + if (msg instanceof Http2HeadersFrame headers) { + response.complete(headers); + } + } finally { + ReferenceCountUtil.release(msg); + } + } + }) + .open() + .syncUninterruptibly() + .getNow(); + + final Http2Headers headers = new DefaultHttp2Headers().method("POST").path("/test"); + stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, false)).syncUninterruptibly(); + + // Make the backend H2 stream processor 'slow' by disabling auto-read on it + final Channel backendServerChannel = backendConnection.join(); + backendServerChannel.config().setAutoRead(false); + + final byte[] chunk = new byte[16384]; + do { + // Write data until our own client hits the high watermark + while (isWritable.get()) { + // Try to wait until the write finishes but if it can't that's fine: we're trying to induce backpressure + stream + .writeAndFlush(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(chunk), false)) + .awaitUninterruptibly(100, TimeUnit.MILLISECONDS); + Thread.yield(); + } + + // Make sure our channel is still unwritable for a bit since we haven't re-enabled auto-read yet. If we become + // writable it means we are hitting a lower watermark somewhere earlier in the stack, so we can try writing some + // more. Eventually all intermediate channels should flush and we should be stuck on the backend channel which + // will never make progress (because auto-read is disabled) + Thread.sleep(100); + } while (isWritable.get()); + stream.writeAndFlush(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(chunk), true)); + + // Now re-enable reads on the backend, which should eventually unblock our writes + backendConnection.resultNow().config().setAutoRead(true); + // Now we should eventually be able to send the last (endStream=true) write and get a response + assertEquals("200", response.get(5, TimeUnit.SECONDS).headers().status().toString()); + assertTrue(isWritable.get()); + + h2Connection.close().syncUninterruptibly(); + } + + @Test + void idleTest() throws Exception { + final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12"); + final Duration timeout = Duration.ofMillis(500); + final OmnibusH2Server timeoutServer = new OmnibusH2Server( + SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD), + nioEventLoopGroup, + localEventLoopGroup, + new InetSocketAddress("127.0.0.1", 0), + new OmnibusRouter( + List.of(new OmnibusRouter.OmnibusRoute(PREFIX, prefixBackend.localAddress())), + defaultBackend.localAddress()), + timeout); + timeoutServer.start(); + + final Channel channel = connectToOmnibus(timeoutServer, null); + + // Send a request to make sure idleTimeouts work even after a stream / backend connection has been established + sendRequestThroughOmnibus(channel, "/a/different/path"); + + // The server should eventually close this idle connection + assertTrue(channel.closeFuture().awaitUninterruptibly(timeout.toMillis() * 5, TimeUnit.MILLISECONDS)); + + timeoutServer.stop(); + } + + private Channel startH2CServer(final boolean local, final String identity) throws InterruptedException { + final EventLoopGroup eventLoopGroup = local ? localEventLoopGroup : nioEventLoopGroup; + return new ServerBootstrap() + .group(eventLoopGroup, eventLoopGroup) + .channel(local ? LocalServerChannel.class : NioServerSocketChannel.class) + // Limit size of kernel TCP buffers to make it easier to hit backpressure in tests + .option(NioChannelOption.SO_RCVBUF, 8192) + .option(NioChannelOption.SO_SNDBUF, 8192) + .childHandler(new ChannelInitializer<>() { + @Override + protected void initChannel(final Channel ch) { + backendConnection.complete(ch); + ch.pipeline().addLast(Http2FrameCodecBuilder.forServer().build()); + ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInitializer() { + @Override + protected void initChannel(final Http2StreamChannel ch) { + ch.pipeline().addLast(new TestHandler(identity)); + } + })); + } + }) + .bind(local ? new LocalAddress(identity) : new InetSocketAddress("127.0.0.1", 0)) + .sync() + .channel(); + } + + private Channel connectToOmnibus() { + return connectToOmnibus(null, null); + } + + /// Makes an H2 connection to the omnibus at [this#server] on which new H2 streams can be opened + private Channel connectToOmnibus(@Nullable OmnibusH2Server server, @Nullable final HAProxyMessage proxyHeader) { + if (server == null) { + server = this.server; + } + final SslContext clientSsl; + try { + clientSsl = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .applicationProtocolConfig(new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_2)) + .build(); + } catch (SSLException e) { + throw new RuntimeException(e); + } + + final Bootstrap clientBootstrap = new Bootstrap() + .group(nioEventLoopGroup) + .channel(NioSocketChannel.class) + // Limit size of kernel TCP buffers to make it easier to hit backpressure in tests + .option(NioChannelOption.SO_RCVBUF, 8192) + .option(NioChannelOption.SO_SNDBUF, 8192) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(final SocketChannel ch) { + ch.pipeline().addLast(HAProxyMessageEncoder.INSTANCE); + } + }); + final Channel ch = clientBootstrap.connect(server.getLocalAddress()) + .syncUninterruptibly() + .channel(); + if (proxyHeader != null) { + ch.writeAndFlush(proxyHeader).syncUninterruptibly(); + } + ch.pipeline().remove(HAProxyMessageEncoder.INSTANCE); + ch.pipeline().addLast(clientSsl.newHandler(ch.alloc(), server.getLocalAddress().getHostName(), server.getLocalAddress().getPort())); + ch.pipeline().addLast(Http2FrameCodecBuilder.forClient() + .initialSettings(Http2Settings.defaultSettings()) + .build()); + ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInboundHandlerAdapter())); + return ch; + } + + /// Sends an H2 request to [this#server] and returns the H2 response body + private String sendRequestThroughOmnibus(final String path) { + final Channel h2Connection = connectToOmnibus(); + final String result = sendRequestThroughOmnibus(h2Connection, path); + h2Connection.close().syncUninterruptibly(); + return result; + } + + private String sendRequestThroughOmnibus(final Channel h2Connection, final String path) { + final CompletableFuture responseFuture = new CompletableFuture<>(); + final Http2StreamChannelBootstrap streamBootstrap = new Http2StreamChannelBootstrap(h2Connection); + final Http2StreamChannel stream = streamBootstrap + .handler(new ResponseCollectorHandler(responseFuture)) + .open() + .syncUninterruptibly() + .getNow(); + + final Http2Headers headers = new DefaultHttp2Headers() + .method("POST") + .path(path) + .scheme("https") + .authority("localhost"); + + stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, true)); + + return responseFuture.join(); + } + + /// A backend that either echos the request body, returns an identity, or disconnects based on the request + private static class TestHandler extends ChannelInboundHandlerAdapter { + + // Returned if request has no body + private final String identity; + private final ByteBuf accumulated = Unpooled.buffer(); + + private TestHandler(final String identity) { + this.identity = identity; + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + if (msg instanceof Http2HeadersFrame headers) { + final String path = headers.headers().path().toString(); + if (path.contains("reset")) { + ctx.writeAndFlush(new DefaultHttp2ResetFrame(Http2Error.NO_ERROR)) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } else if (path.contains("goaway")) { + ctx.channel().parent() + .writeAndFlush(new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR)) + .addListener(ChannelFutureListener.CLOSE); + } else if (path.contains("forwarded-for")) { + final String xForwardedFor = Optional + .ofNullable(headers.headers().get("x-forwarded-for")) + .map(CharSequence::toString) + .orElse(""); + writeResponse(ctx, Unpooled.copiedBuffer(xForwardedFor, StandardCharsets.UTF_8)); + } else if (headers.isEndStream()) { + writeResponse(ctx, Unpooled.copiedBuffer(identity, StandardCharsets.UTF_8)); + } + } else if (msg instanceof Http2DataFrame dataFrame) { + accumulated.writeBytes(dataFrame.content()); + if (dataFrame.isEndStream()) { + writeResponse(ctx, accumulated); + } + } + ReferenceCountUtil.release(msg); + } + + private void writeResponse(final ChannelHandlerContext ctx, final ByteBuf body) { + final Http2Headers responseHeaders = new DefaultHttp2Headers().status("200"); + ctx.write(new DefaultHttp2HeadersFrame(responseHeaders, false)); + ctx.writeAndFlush(new DefaultHttp2DataFrame(body, true)); + } + } + + /// Completes the provided future with the first [Http2DataFrame] received + private static class ResponseCollectorHandler extends ChannelInboundHandlerAdapter { + + private final CompletableFuture responseFuture; + + ResponseCollectorHandler(final CompletableFuture responseFuture) { + this.responseFuture = responseFuture; + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + if (msg instanceof Http2DataFrame dataFrame) { + responseFuture.complete(dataFrame.content().toString(StandardCharsets.UTF_8)); + } + ReferenceCountUtil.release(msg); + } + + @Override + public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) { + responseFuture.completeExceptionally(cause); + } + } + + /// Completes the provided future with the first [Http2HeadersFrame] received + private static class HeadersCollectorHandler extends ChannelInboundHandlerAdapter { + + private final CompletableFuture responseFuture; + + HeadersCollectorHandler(final CompletableFuture responseFuture) { + this.responseFuture = responseFuture; + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + if (msg instanceof Http2HeadersFrame headers) { + responseFuture.complete(headers); + } + ReferenceCountUtil.release(msg); + } + + @Override + public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) { + responseFuture.completeExceptionally(cause); + } + } + + /// Completes the provided future when an RST frame is received or errors if we don't get one + private static class RstCollectorHandler extends SimpleUserEventChannelHandler { + + private final CompletableFuture resetFuture; + + private RstCollectorHandler(final CompletableFuture resetFuture) { + this.resetFuture = resetFuture; + } + + @Override + protected void eventReceived(final ChannelHandlerContext ctx, final Http2ResetFrame evt) { + resetFuture.complete(evt); + } + + @Override + public void channelInactive(final ChannelHandlerContext ctx) { + if (!resetFuture.isDone()) { + resetFuture.completeExceptionally(new IllegalStateException("Channel went inactive without RST")); + } + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolHandlerTest.java new file mode 100644 index 000000000..0e3b2cb9b --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolHandlerTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.haproxy.HAProxyCommand; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; +import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; +import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; +import org.junit.jupiter.api.Test; + +class ProxyProtocolHandlerTest extends AbstractLeakDetectionTest { + private static final HAProxyMessage PROXY_MESSAGE = new HAProxyMessage( + HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, + HAProxyProxiedProtocol.TCP4, "10.0.0.1", "10.0.0.2", 1234, 5678); + + @Test + void sendHeader() { + final EmbeddedChannel encoder = new EmbeddedChannel(HAProxyMessageEncoder.INSTANCE); + encoder.writeOutbound(PROXY_MESSAGE.retain()); + final ByteBuf ppv2Bytes = encoder.readOutbound(); + + final EmbeddedChannel ch = new EmbeddedChannel(new ProxyProtocolHandler()); + ch.writeInbound(ppv2Bytes); + final HAProxyMessage actual = ch.readInbound(); + assertEquals(PROXY_MESSAGE.protocolVersion(), actual.protocolVersion()); + assertEquals(PROXY_MESSAGE.sourceAddress(), actual.sourceAddress()); + } + + @Test + void sendHeaderSlowly() { + final EmbeddedChannel encoder = new EmbeddedChannel(HAProxyMessageEncoder.INSTANCE); + encoder.writeOutbound(PROXY_MESSAGE.retain()); + final ByteBuf ppv2Bytes = encoder.readOutbound(); + + final EmbeddedChannel ch = new EmbeddedChannel(new ProxyProtocolHandler()); + while (ppv2Bytes.isReadable()) { + assertNull(ch.readInbound()); + ch.writeInbound(ppv2Bytes.readBytes(1)); + } + + final HAProxyMessage actual = ch.readInbound(); + assertEquals(PROXY_MESSAGE.protocolVersion(), actual.protocolVersion()); + assertEquals(PROXY_MESSAGE.sourceAddress(), actual.sourceAddress()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/SniMapperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/SniMapperTest.java new file mode 100644 index 000000000..545b4862e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/SniMapperTest.java @@ -0,0 +1,179 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.SimpleUserEventChannelHandler; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.ssl.SniHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslHandler; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.util.Mapping; +import java.io.InputStream; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class SniMapperTest { + + // Configuration for precomputed keystore blob defined in sni-mapper-test-keystore.p12 + private static final String FOO_DOMAIN = "foo.example.com"; + private static final String BAR_DOMAIN = "bar.example.com"; + private static final String KEY_STORE_PASSWORD = "password"; + private static final String KEY_STORE_NAME = "sni-mapper-test-keystore.p12"; + + private DefaultEventLoopGroup eventLoopGroup; + private Channel serverChannel; + + @BeforeEach + void setUp() throws Exception { + final InputStream keyStore = SniMapper.class.getResourceAsStream(KEY_STORE_NAME); + eventLoopGroup = new DefaultEventLoopGroup(); + + final Mapping sniMapping = + SniMapper.buildSniMapping(keyStore, KEY_STORE_PASSWORD); + + final LocalAddress localAddress = new LocalAddress(SniMapper.class.getSimpleName()); + serverChannel = new ServerBootstrap() + .group(eventLoopGroup) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer<>() { + @Override + protected void initChannel(final Channel ch) { + ch.pipeline().addLast(new SniHandler(sniMapping)); + } + }) + .bind(localAddress) + .sync() + .channel(); + } + + @AfterEach + void tearDown() throws Exception { + if (serverChannel != null) { + serverChannel.close().sync(); + } + eventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync(); + } + + @Test + void unknownDomain() throws Exception { + final InputStream keyStore = SniMapper.class.getResourceAsStream(KEY_STORE_NAME); + final Mapping sniMapping = SniMapper.buildSniMapping(keyStore, KEY_STORE_PASSWORD); + assertNotNull(sniMapping.map("unknown.example.com")); + final X509Certificate defaultCertificate = connectAndGetServerCertificate("unknown.example.com", null); + + // bar.example.com is the lexicographically first domain, so we should default to it. + assertCertificateIsForDomain(defaultCertificate, BAR_DOMAIN); + } + + static List selectCertificate() { + return List.of( + Arguments.of(FOO_DOMAIN, List.of(), "Ed25519"), + Arguments.of(BAR_DOMAIN, List.of(), "Ed25519"), + Arguments.of(BAR_DOMAIN, List.of("ed25519"), "Ed25519"), + Arguments.of(FOO_DOMAIN, List.of("rsa_pss_rsae_sha256", "rsa_pss_rsae_sha384", "rsa_pss_rsae_sha512"), "SHA256withRSA"), + Arguments.of(FOO_DOMAIN, List.of("rsa_pss_rsae_sha256", "rsa_pss_rsae_sha384", "rsa_pss_rsae_sha512", "ed25519"), "SHA256withRSA"), + Arguments.of(FOO_DOMAIN, List.of("ed25519", "rsa_pss_rsae_sha256", "rsa_pss_rsae_sha384", "rsa_pss_rsae_sha512"), "Ed25519") + ); + } + + @ParameterizedTest + @MethodSource + void selectCertificate(final String sni, final List signatureSchemes, final String expectedSigAlgorithm) + throws Exception { + final X509Certificate serverCert = connectAndGetServerCertificate(sni, signatureSchemes.toArray(String[]::new)); + assertNotNull(serverCert); + assertCertificateIsForDomain(serverCert, sni); + assertEquals(expectedSigAlgorithm, serverCert.getSigAlgName()); + } + + private X509Certificate connectAndGetServerCertificate(final String sniHostname, + final String[] signatureSchemes) throws Exception { + final SslContext clientSsl = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .protocols("TLSv1.3") + .build(); + + final CompletableFuture certFuture = new CompletableFuture<>(); + + final Bootstrap clientBootstrap = new Bootstrap() + .group(eventLoopGroup) + .channel(LocalChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(final LocalChannel ch) { + final SSLEngine engine = clientSsl.newEngine(ch.alloc()); + + final SSLParameters params = engine.getSSLParameters(); + params.setServerNames(List.of(new SNIHostName(sniHostname))); + if (signatureSchemes != null && signatureSchemes.length != 0) { + params.setSignatureSchemes(signatureSchemes); + } + engine.setSSLParameters(params); + + final SslHandler sslHandler = new SslHandler(engine); + ch.pipeline().addLast(sslHandler); + ch.pipeline().addLast(new SimpleUserEventChannelHandler() { + @Override + protected void eventReceived(final ChannelHandlerContext ctx, final SslHandshakeCompletionEvent evt) { + if (!evt.isSuccess()) { + certFuture.completeExceptionally(evt.cause()); + return; + } + try { + final SSLSession session = sslHandler.engine().getSession(); + final X509Certificate cert = (X509Certificate) session.getPeerCertificates()[0]; + certFuture.complete(cert); + } catch (final SSLPeerUnverifiedException e) { + certFuture.completeExceptionally(e); + } + } + }); + } + }); + + final Channel clientChannel = clientBootstrap.connect(serverChannel.localAddress()).sync().channel(); + try { + return certFuture.get(5, TimeUnit.SECONDS); + } finally { + clientChannel.close().sync(); + } + } + + private static void assertCertificateIsForDomain(final X509Certificate cert, final String expectedDomain) + throws Exception { + assertTrue(cert.getSubjectAlternativeNames().stream() + .filter(san -> (int) san.getFirst() == 2) // dNSName + .map(san -> (String) san.get(1)) + .anyMatch(name -> name.equalsIgnoreCase(expectedDomain))); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerIntegrationTest.java similarity index 73% rename from service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerIntegrationTest.java index 3531df0e7..57a8f6465 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerIntegrationTest.java @@ -28,7 +28,6 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import jakarta.annotation.Priority; -import jakarta.servlet.DispatcherType; import jakarta.ws.rs.GET; import jakarta.ws.rs.InternalServerErrorException; import jakarta.ws.rs.NotAuthorizedException; @@ -45,7 +44,6 @@ import java.io.IOException; import java.net.URI; import java.security.Principal; import java.time.Duration; -import java.util.EnumSet; import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -55,25 +53,28 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import java.util.stream.Stream; import javax.security.auth.Subject; -import org.eclipse.jetty.server.Connector; -import org.eclipse.jetty.server.HttpChannel; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.server.Handler; import org.eclipse.jetty.server.Request; -import org.eclipse.jetty.util.component.Container; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.handler.EventsHandler; import org.eclipse.jetty.util.component.LifeCycle; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.websocket.WebSocketResourceProviderFactory; @@ -81,7 +82,8 @@ import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.setup.WebSocketEnvironment; @ExtendWith(DropwizardExtensionsSupport.class) -class MetricsHttpChannelListenerIntegrationTest { +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +class MetricsHttpEventHandlerIntegrationTest { private static final TrafficSource TRAFFIC_SOURCE = TrafficSource.HTTP; private static final MeterRegistry METER_REGISTRY = mock(MeterRegistry.class); @@ -91,7 +93,7 @@ class MetricsHttpChannelListenerIntegrationTest { private static final AtomicReference COUNT_DOWN_LATCH_FUTURE_REFERENCE = new AtomicReference<>(); private static final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>( - MetricsHttpChannelListenerIntegrationTest.TestApplication.class); + MetricsHttpEventHandlerIntegrationTest.TestApplication.class); @AfterEach void teardown() { @@ -112,9 +114,9 @@ class MetricsHttpChannelListenerIntegrationTest { final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final Map counterMap = Map.of( - MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER, - MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, - MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER + MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER, + MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, + MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER ); when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) .thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class))); @@ -148,7 +150,7 @@ class MetricsHttpChannelListenerIntegrationTest { assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS)); - verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(REQUEST_COUNTER).increment(); final Iterable tagIterable = tagCaptor.getValue(); @@ -159,10 +161,10 @@ class MetricsHttpChannelListenerIntegrationTest { } assertEquals(Set.of( - Tag.of(MetricsHttpChannelListener.PATH_TAG, expectedTagPath), - Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET"), - Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(expectedStatus)), - Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()), + Tag.of(MetricsHttpEventHandler.PATH_TAG, expectedTagPath), + Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET"), + Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(expectedStatus)), + Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()), Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")), tags); } @@ -186,7 +188,7 @@ class MetricsHttpChannelListenerIntegrationTest { @Test void testWebSocketUpgrade() throws Exception { - final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest(); + final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest(URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket"))); upgradeRequest.setHeader(HttpHeaders.USER_AGENT, "Signal-Android/4.53.7 (Android 8.1)"); final CountDownLatch countDownLatch = new CountDownLatch(1); @@ -194,24 +196,18 @@ class MetricsHttpChannelListenerIntegrationTest { final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final Map counterMap = Map.of( - MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER, - MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, - MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER + MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER, + MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, + MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER ); when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) .thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class))); - client.connect(new WebSocketListener() { - @Override - public void onWebSocketConnect(final Session session) { - session.close(1000, "OK"); - } - }, - URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest); + client.connect(new AutoClosingWebSocketSessionListener(), upgradeRequest); assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS)); - verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(REQUEST_COUNTER).increment(); final Iterable tagIterable = tagCaptor.getValue(); @@ -222,10 +218,10 @@ class MetricsHttpChannelListenerIntegrationTest { } assertEquals(Set.of( - Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket"), - Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET"), - Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101)), - Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()), + Tag.of(MetricsHttpEventHandler.PATH_TAG, "/v1/websocket"), + Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET"), + Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(101)), + Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()), Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")), tags); } @@ -248,9 +244,9 @@ class MetricsHttpChannelListenerIntegrationTest { final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final Map counterMap = Map.of( - MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER, - MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, - MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER + MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER, + MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, + MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER ); when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) .thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class))); @@ -263,7 +259,7 @@ class MetricsHttpChannelListenerIntegrationTest { assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS)); - verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(REQUEST_COUNTER).increment(); final Iterable tagIterable = tagCaptor.getValue(); @@ -273,7 +269,7 @@ class MetricsHttpChannelListenerIntegrationTest { tags.add(tag); } - assertFalse(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "ANNOY"))); + assertFalse(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, "ANNOY"))); } } @@ -283,17 +279,16 @@ class MetricsHttpChannelListenerIntegrationTest { public void run(final Configuration configuration, final Environment environment) throws Exception { - final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener( - METER_REGISTRY, - mock(ClientReleaseManager.class), - Set.of("/v1/websocket") - ); + MetricsHttpEventHandler.configure(environment, METER_REGISTRY, mock(ClientReleaseManager.class), Set.of("/v1/websocket")); - metricsHttpChannelListener.configure(environment); - environment.lifecycle().addEventListener(new TestListener(COUNT_DOWN_LATCH_FUTURE_REFERENCE)); - - environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); + environment.lifecycle().addEventListener(new LifeCycle.Listener() { + @Override + public void lifeCycleStarting(final LifeCycle event) { + if (event instanceof Server server) { + server.setHandler(new TestListener(server.getHandler(), COUNT_DOWN_LATCH_FUTURE_REFERENCE)); + } + } + }); environment.jersey().register(new TestResource()); environment.jersey().register(new TestAuthFilter()); @@ -306,14 +301,15 @@ class MetricsHttpChannelListenerIntegrationTest { webSocketEnvironment.jersey().register(new TestResource()); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); - WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, + webSocketEnvironment, TestPrincipal.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); - environment.servlets().addServlet("WebSocket", webSocketServlet) - .addMapping("/v1/websocket"); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), + (servletContext, container) -> { + container.addMapping("/v1/websocket", webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); } } @@ -329,36 +325,23 @@ class MetricsHttpChannelListenerIntegrationTest { } /** - * A simple listener to signal that {@link HttpChannel.Listener} has completed its work, since its onComplete() is on + * A simple listener to signal that {@link EventsHandler} has completed its work, since its onComplete() is on * a different thread from the one that sends the response, creating a race condition between the listener and the * test assertions */ - static class TestListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener { + static class TestListener extends EventsHandler { private final AtomicReference completableFutureAtomicReference; - TestListener(AtomicReference countDownLatchReference) { - + TestListener(final Handler handler, AtomicReference countDownLatchReference) { + super(handler); this.completableFutureAtomicReference = countDownLatchReference; } @Override - public void onComplete(final Request request) { + public void onComplete(Request request, int status, HttpFields headers, Throwable failure) { completableFutureAtomicReference.get().countDown(); } - - @Override - public void beanAdded(final Container parent, final Object child) { - if (child instanceof Connector connector) { - connector.addBean(this); - } - } - - @Override - public void beanRemoved(final Container parent, final Object child) { - - } - } @Path("/v1/test") @@ -400,4 +383,11 @@ class MetricsHttpChannelListenerIntegrationTest { } } + public static class AutoClosingWebSocketSessionListener implements Session.Listener.AutoDemanding { + @Override + public void onWebSocketOpen(final Session session) { + session.close(1000, "OK", Callback.NOOP); + } + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerTest.java similarity index 53% rename from service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerTest.java index cc711fb96..8c692072c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerTest.java @@ -20,24 +20,33 @@ import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; import java.util.Collections; +import java.nio.ByteBuffer; import java.util.HashSet; import java.util.List; import java.util.Set; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.http.HttpHeader; import org.eclipse.jetty.http.HttpURI; +import org.eclipse.jetty.io.Content; import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Response; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ContainerResponse; import org.glassfish.jersey.server.ExtendedUriInfo; import org.glassfish.jersey.uri.UriTemplate; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import javax.annotation.Nullable; -class MetricsHttpChannelListenerTest { +class MetricsHttpEventHandlerTest { + private final static String USER_AGENT = "Signal-Android/6.53.7 (Android 8.1)"; private MeterRegistry meterRegistry; private Counter requestCounter; @@ -45,7 +54,7 @@ class MetricsHttpChannelListenerTest { private Counter responseBytesCounter; private Counter requestBytesCounter; private ClientReleaseManager clientReleaseManager; - private MetricsHttpChannelListener listener; + private MetricsHttpEventHandler listener; @BeforeEach void setup() { @@ -55,27 +64,28 @@ class MetricsHttpChannelListenerTest { responseBytesCounter = mock(Counter.class); requestBytesCounter = mock(Counter.class); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class))) + when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestCounter); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Tags.class))) + when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME), any(Tags.class))) .thenReturn(requestsByVersionCounter); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class))) + when(meterRegistry.counter(eq(MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(responseBytesCounter); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) + when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestBytesCounter); clientReleaseManager = mock(ClientReleaseManager.class); - listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet()); + listener = new MetricsHttpEventHandler(null, meterRegistry, clientReleaseManager, Set.of("/test")); } - @ParameterizedTest - @ValueSource(booleans = {true, false}) + @CartesianTest @SuppressWarnings("unchecked") - void testRequests(final boolean versionActive) { + void testRequests(@CartesianTest.Values(booleans = {true, false}) final boolean pathFromFilter, + @CartesianTest.Values(booleans = {true, false}) final boolean versionActive) { + final String path = "/test"; final String method = "GET"; final int statusCode = 200; @@ -85,30 +95,40 @@ class MetricsHttpChannelListenerTest { final Request request = mock(Request.class); when(request.getMethod()).thenReturn(method); - when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)"); + + final HttpFields.Mutable requestHeaders = HttpFields.build(); + requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT); + when(request.getHeaders()).thenReturn(requestHeaders); when(request.getHttpURI()).thenReturn(httpUri); + if (pathFromFilter) { + when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)) + .thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT)); + } else { + when(request.setAttribute(eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME), any())).thenAnswer(invocation -> { + when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)) + .thenReturn(invocation.getArgument(1)); + return null; + }); + } + final Response response = mock(Response.class); when(response.getStatus()).thenReturn(statusCode); - when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive); - when(response.getContentCount()).thenReturn(1024L); - when(request.getResponse()).thenReturn(response); - when(request.getContentRead()).thenReturn(512L); - final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); - when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo); - when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path))); + when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive); final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); - listener.onComplete(request); + listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true)); + listener.onResponseWrite(request, true, ByteBuffer.allocate(1024)); + listener.onComplete(request, statusCode, requestHeaders, null); verify(requestCounter).increment(); verify(responseBytesCounter).increment(1024L); verify(requestBytesCounter).increment(512L); - verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); final Set tags = new HashSet<>(); for (final Tag tag : tagCaptor.getValue()) { @@ -116,10 +136,10 @@ class MetricsHttpChannelListenerTest { } final Set expectedTags = new HashSet<>(Set.of( - Tag.of(MetricsHttpChannelListener.PATH_TAG, path), - Tag.of(MetricsHttpChannelListener.METHOD_TAG, method), - Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(statusCode)), - Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()), + Tag.of(MetricsHttpEventHandler.PATH_TAG, path), + Tag.of(MetricsHttpEventHandler.METHOD_TAG, method), + Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(statusCode)), + Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()), Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); if (versionActive) { @@ -143,23 +163,22 @@ class MetricsHttpChannelListenerTest { final Request request = mock(Request.class); when(request.getMethod()).thenReturn(method); - when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)"); + final HttpFields.Mutable requestHeaders = HttpFields.build(); + requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT); + when(request.getHeaders()).thenReturn(requestHeaders); when(request.getHttpURI()).thenReturn(httpUri); + when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)) + .thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT)); final Response response = mock(Response.class); when(response.getStatus()).thenReturn(statusCode); - when(response.getContentCount()).thenReturn(1024L); - when(request.getResponse()).thenReturn(response); - when(request.getContentRead()).thenReturn(512L); - final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); - when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo); - when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path))); - - listener.onComplete(request); + listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true)); + listener.onResponseWrite(request, true, ByteBuffer.allocate(1024)); + listener.onComplete(request, statusCode, requestHeaders, null); if (versionActive) { final ArgumentCaptor tagCaptor = ArgumentCaptor.forClass(Tags.class); - verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), + verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME), tagCaptor.capture()); final Set tags = new HashSet<>(); tags.clear(); @@ -178,7 +197,7 @@ class MetricsHttpChannelListenerTest { @ParameterizedTest @MethodSource void normalizeMethod(@Nullable final String originalMethod, final String expectedMethod) { - assertEquals(expectedMethod, MetricsHttpChannelListener.normalizeMethod(originalMethod)); + assertEquals(expectedMethod, MetricsHttpEventHandler.normalizeMethod(originalMethod)); } private static List normalizeMethod() { @@ -190,4 +209,37 @@ class MetricsHttpChannelListenerTest { Arguments.arguments("get", "get") ); } + + @Test + void testResponseFilterSetsRequestInfo() { + final ContainerRequest request = mock(ContainerRequest.class); + + final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); + when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test"))); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeaders()).thenReturn(null); + when(request.getUriInfo()).thenReturn(extendedUriInfo); + when(request.getHeaderString(HttpHeaders.USER_AGENT)).thenReturn(USER_AGENT); + + new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class)); + + verify(request).setProperty( + eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME), + eq(new MetricsHttpEventHandler.RequestInfo("/test", "GET", USER_AGENT))); + } + + @Test + void testResponseFilterModifiesRequestInfo() { + final MetricsHttpEventHandler.RequestInfo requestInfo = + new MetricsHttpEventHandler.RequestInfo("unknown", "POST", USER_AGENT); + + final ContainerRequest request = mock(ContainerRequest.class); + when(request.getProperty(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)).thenReturn(requestInfo); + final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); + when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test"))); + when(request.getUriInfo()).thenReturn(extendedUriInfo); + new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class)); + + assertEquals(new MetricsHttpEventHandler.RequestInfo("/test", "POST", USER_AGENT), requestInfo); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index 46f3bcace..222c7e251 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -36,10 +36,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerResponse; @@ -175,11 +174,9 @@ class MetricsRequestEventListenerTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); - final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); final UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)"); when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of("Signal-Android/4.53.7 (Android 8.1)"))); @@ -191,15 +188,15 @@ class MetricsRequestEventListenerTest { when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestBytesCounter); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); final ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -244,11 +241,9 @@ class MetricsRequestEventListenerTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); - final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); final UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn( @@ -258,15 +253,15 @@ class MetricsRequestEventListenerTest { when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestBytesCounter); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); final ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -314,11 +309,9 @@ class MetricsRequestEventListenerTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); - final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); final UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn( @@ -328,15 +321,15 @@ class MetricsRequestEventListenerTest { when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestBytesCounter); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); final ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtilTest.java index 3b3c53adf..41c6c088f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtilTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtilTest.java @@ -149,15 +149,13 @@ class TlsCertificateExpirationUtilTest { @Test void test() throws Exception { - try (Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64)) { + final Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64); + final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD); - final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD); - - final Map expected = Map.of( - "localhost:EdDSA", EDDSA_EXPIRATION, - "localhost:RSA", RSA_EXPIRATION); - assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD)); - } + final Map expected = Map.of( + "localhost:EdDSA", EDDSA_EXPIRATION, + "localhost:RSA", RSA_EXPIRATION); + assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClientTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClientTest.java index 5d1d9cbe5..8b97f100f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClientTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClientTest.java @@ -419,13 +419,6 @@ class FaultTolerantRedisClusterClientTest { return this; } - @Override - @Deprecated - public ClientResources.Builder dnsResolver(final DnsResolver dnsResolver) { - delegate.dnsResolver(dnsResolver); - return this; - } - @Override public ClientResources.Builder eventBus(final EventBus eventBus) { delegate.eventBus(eventBus); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java index 4c1cdf3f9..44d904646 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java @@ -4,14 +4,6 @@ */ package org.whispersystems.textsecuregcm.tests.util; -import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WebSocketListener; -import org.whispersystems.websocket.messages.WebSocketMessage; -import org.whispersystems.websocket.messages.WebSocketMessageFactory; -import org.whispersystems.websocket.messages.WebSocketResponseMessage; -import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; - -import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; import java.util.Objects; @@ -19,8 +11,14 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import org.eclipse.jetty.websocket.api.Callback; +import org.eclipse.jetty.websocket.api.Session; +import org.whispersystems.websocket.messages.WebSocketMessage; +import org.whispersystems.websocket.messages.WebSocketMessageFactory; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; -public class TestWebsocketListener implements WebSocketListener { +public class TestWebsocketListener implements Session.Listener.AutoDemanding { private final AtomicLong requestId = new AtomicLong(); private final CompletableFuture started = new CompletableFuture<>(); @@ -34,7 +32,7 @@ public class TestWebsocketListener implements WebSocketListener { @Override - public void onWebSocketConnect(final Session session) { + public void onWebSocketOpen(final Session session) { started.complete(session); } @@ -63,19 +61,15 @@ public class TestWebsocketListener implements WebSocketListener { responseFutures.put(id, future); final byte[] requestBytes = messageFactory.createRequest( Optional.of(id), verb, requestPath, headers, body).toByteArray(); - try { - session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes)); - } catch (IOException e) { - throw new RuntimeException(e); - } + session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP); return future; }); } @Override - public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { responseFutures.get(webSocketMessage.getResponseMessage().getRequestId()) .complete(webSocketMessage.getResponseMessage()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/jetty/TestResource.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/jetty/TestResource.java index 831c4e526..e02282a65 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/jetty/TestResource.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/jetty/TestResource.java @@ -6,12 +6,9 @@ package org.whispersystems.textsecuregcm.util.jetty; import java.io.ByteArrayInputStream; -import java.io.File; -import java.io.IOException; import java.io.InputStream; -import java.net.MalformedURLException; import java.net.URI; -import java.nio.channels.ReadableByteChannel; +import java.nio.file.Path; import java.util.Base64; import org.eclipse.jetty.util.resource.Resource; @@ -30,15 +27,16 @@ public class TestResource extends Resource { } @Override - public boolean isContainedIn(final Resource r) throws MalformedURLException { - return false; + public Path getPath() { + return null; } @Override - public void close() { - + public InputStream newInputStream() { + return new ByteArrayInputStream(data); } + @Override public boolean exists() { return true; @@ -50,13 +48,13 @@ public class TestResource extends Resource { } @Override - public long lastModified() { - return 0; + public boolean isReadable() { + return true; } @Override public long length() { - return 0; + return data.length; } @Override @@ -64,43 +62,19 @@ public class TestResource extends Resource { return null; } - @Override - public File getFile() throws IOException { - return null; - } - @Override public String getName() { return name; } @Override - public InputStream getInputStream() throws IOException { - return new ByteArrayInputStream(data); + public String getFileName() { + return ""; } @Override - public ReadableByteChannel getReadableByteChannel() throws IOException { + public Resource resolve(final String subUriPath) { return null; } - @Override - public boolean delete() throws SecurityException { - return false; - } - - @Override - public boolean renameTo(final Resource dest) throws SecurityException { - return false; - } - - @Override - public String[] list() { - return new String[]{name}; - } - - @Override - public Resource addPath(final String path) throws IOException, MalformedURLException { - return this; - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java index ca4a17cf7..3970e868c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java @@ -39,10 +39,9 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.stream.Stream; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ResourceConfig; import org.glassfish.jersey.server.ServerProperties; @@ -143,12 +142,12 @@ class LoggingUnhandledExceptionMapperTest { WebSocketResourceProvider provider = createWebsocketProvider(userAgentHeader, session, responseFuture::complete); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory() .createRequest(Optional.of(111L), "GET", targetPath, new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); responseFuture.get(1, TimeUnit.SECONDS); @@ -179,15 +178,13 @@ class LoggingUnhandledExceptionMapperTest { TestPrincipal.authenticatedTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); doAnswer(answer -> { responseHandler.accept(answer.getArgument(0, ByteBuffer.class)); return null; - }).when(remoteEndpoint).sendBytes(any(), any(WriteCallback.class)); + }).when(session).sendBinary(any(ByteBuffer.class), any(Callback.class)); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn(userAgentHeader); when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of(userAgentHeader))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java index 4ecd722f3..55e7e7cc3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java @@ -19,7 +19,7 @@ import java.util.Optional; import java.util.UUID; import java.util.stream.Stream; import javax.annotation.Nullable; -import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -44,7 +44,7 @@ class WebSocketAccountAuthenticatorTest { private AccountAuthenticator accountAuthenticator; - private UpgradeRequest upgradeRequest; + private JettyServerUpgradeRequest upgradeRequest; @BeforeEach void setUp() { @@ -56,7 +56,7 @@ class WebSocketAccountAuthenticatorTest { when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) .thenReturn(Optional.empty()); - upgradeRequest = mock(UpgradeRequest.class); + upgradeRequest = mock(JettyServerUpgradeRequest.class); } @ParameterizedTest diff --git a/service/src/test/resources/config/test.yml b/service/src/test/resources/config/test.yml index 9fb5328bd..fadd66ed9 100644 --- a/service/src/test/resources/config/test.yml +++ b/service/src/test/resources/config/test.yml @@ -337,6 +337,8 @@ dynamicConfig: object: | captcha: scoreFloor: 1.0 + grpcAllowList: + enableAll: true remoteConfig: globalConfig: # keys and values that are given to clients on GET /v1/config @@ -521,6 +523,8 @@ idlePrimaryDeviceReminder: grpc: port: 50051 + websocketPort: 8080 + h2c: true asnTable: s3Region: a-region @@ -536,3 +540,18 @@ callQualitySurvey: hlrLookup: apiKey: secret://hlrLookup.apiKey apiSecret: secret://hlrLookup.apiSecret + +server: + allowedMethods: + - GET + - POST + - PUT + - DELETE + - HEAD + - OPTIONS + - PATCH + - CONNECT + applicationConnectors: + - type: h2c + port: 8080 + useForwardedHeaders: true diff --git a/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/generate-omnibus-h2-server-test-certs.sh b/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/generate-omnibus-h2-server-test-certs.sh new file mode 100755 index 000000000..5422e23bc --- /dev/null +++ b/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/generate-omnibus-h2-server-test-certs.sh @@ -0,0 +1,19 @@ +#!/bin/sh +# Generates self-signed local testing certificates for OmnibusH2ServerTest + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +WORK_DIR="$(mktemp -d)" +trap 'rm -rf "$WORK_DIR"' EXIT + +PASSWORD="password" +DAYS=36500 +OMNIBUS_KS="$SCRIPT_DIR/omnibus-h2-server-test-keystore.p12" + +openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out "$WORK_DIR/foo-rsa.key" 2>/dev/null; +openssl req -new -x509 -key "$WORK_DIR/foo-rsa.key" -out "$WORK_DIR/foo-rsa.crt" -days "$DAYS" -subj "/CN=foo.example.com" -addext "subjectAltName=DNS:foo.example.com" +openssl pkcs12 -export -in "$WORK_DIR/foo-rsa.crt" -inkey "$WORK_DIR/foo-rsa.key" -out "$WORK_DIR/foo-rsa.p12" -name foo -passout pass:$PASSWORD +keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/foo-rsa.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$OMNIBUS_KS" -deststoretype PKCS12 -deststorepass $PASSWORD + +echo "Wrote keystore to $OMNIBUS_KS" diff --git a/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/generate-sni-mapping-test-certs.sh b/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/generate-sni-mapping-test-certs.sh new file mode 100755 index 000000000..418e5766b --- /dev/null +++ b/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/generate-sni-mapping-test-certs.sh @@ -0,0 +1,36 @@ +#!/bin/sh +# Generates self-signed local testing certificates for SniMapperTest + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +WORK_DIR="$(mktemp -d)" +trap 'rm -rf "$WORK_DIR"' EXIT + +PASSWORD="password" +DAYS=36500 +SNI_KS="$SCRIPT_DIR/sni-mapper-test-keystore.p12" + +openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out "$WORK_DIR/foo-rsa.key" 2>/dev/null; +openssl req -new -x509 -key "$WORK_DIR/foo-rsa.key" -out "$WORK_DIR/foo-rsa.crt" -days "$DAYS" -subj "/CN=foo.example.com" -addext "subjectAltName=DNS:foo.example.com" + +openssl genpkey -algorithm Ed25519 -out "$WORK_DIR/foo-ed25519.key" 2>/dev/null; +openssl req -new -x509 -key "$WORK_DIR/foo-ed25519.key" -out "$WORK_DIR/foo-ed25519.crt" -days "$DAYS" -subj "/CN=foo.example.com" -addext "subjectAltName=DNS:foo.example.com" + +openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out "$WORK_DIR/bar-rsa.key" 2>/dev/null; +openssl req -new -x509 -key "$WORK_DIR/bar-rsa.key" -out "$WORK_DIR/bar-rsa.crt" -days "$DAYS" -subj "/CN=bar.example.com" -addext "subjectAltName=DNS:bar.example.com" + +openssl genpkey -algorithm Ed25519 -out "$WORK_DIR/bar-ed25519.key" 2>/dev/null; +openssl req -new -x509 -key "$WORK_DIR/bar-ed25519.key" -out "$WORK_DIR/bar-ed25519.crt" -days "$DAYS" -subj "/CN=BAR.EXAMPLE.COM" -addext "subjectAltName=DNS:BAR.EXAMPLE.COM" + +openssl pkcs12 -export -in "$WORK_DIR/foo-rsa.crt" -inkey "$WORK_DIR/foo-rsa.key" -out "$WORK_DIR/foo-rsa.p12" -name foo -passout pass:$PASSWORD +openssl pkcs12 -export -in "$WORK_DIR/foo-ed25519.crt" -inkey "$WORK_DIR/foo-ed25519.key" -out "$WORK_DIR/foo-ed25519.p12" -name foo-ed25519 -passout pass:$PASSWORD +openssl pkcs12 -export -in "$WORK_DIR/bar-rsa.crt" -inkey "$WORK_DIR/bar-rsa.key" -out "$WORK_DIR/bar-rsa.p12" -name bar -passout pass:$PASSWORD +openssl pkcs12 -export -in "$WORK_DIR/bar-ed25519.crt" -inkey "$WORK_DIR/bar-ed25519.key" -out "$WORK_DIR/bar-ed25519.p12" -name bar-ed25519 -passout pass:$PASSWORD + +keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/foo-ed25519.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$SNI_KS" -deststoretype PKCS12 -deststorepass $PASSWORD +keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/foo-rsa.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$SNI_KS" -deststoretype PKCS12 -deststorepass $PASSWORD +keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/bar-ed25519.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$SNI_KS" -deststoretype PKCS12 -deststorepass $PASSWORD +keytool -importkeystore -noprompt -srckeystore "$WORK_DIR/bar-rsa.p12" -srcstoretype PKCS12 -srcstorepass $PASSWORD -destkeystore "$SNI_KS" -deststoretype PKCS12 -deststorepass $PASSWORD + +echo "Wrote 4 certificates for two SNIs to $SNI_KS" diff --git a/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/omnibus-h2-server-test-keystore.p12 b/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/omnibus-h2-server-test-keystore.p12 new file mode 100644 index 000000000..3e6200dd8 Binary files /dev/null and b/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/omnibus-h2-server-test-keystore.p12 differ diff --git a/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/sni-mapper-test-keystore.p12 b/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/sni-mapper-test-keystore.p12 new file mode 100644 index 000000000..25125da35 Binary files /dev/null and b/service/src/test/resources/org/whispersystems/textsecuregcm/grpc/net/sni-mapper-test-keystore.p12 differ diff --git a/websocket-resources/pom.xml b/websocket-resources/pom.xml index 156ff3dd7..8ed3c7d4f 100644 --- a/websocket-resources/pom.xml +++ b/websocket-resources/pom.xml @@ -13,15 +13,19 @@ org.eclipse.jetty.websocket - websocket-jetty-api + jetty-websocket-jetty-api org.eclipse.jetty.websocket - websocket-jetty-server + jetty-websocket-jetty-server - org.eclipse.jetty.websocket - websocket-servlet + org.eclipse.jetty.ee10.websocket + jetty-ee10-websocket-jetty-server + + + org.eclipse.jetty.ee10.websocket + jetty-ee10-websocket-servlet io.dropwizard diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java index 1bcfd726d..bc1cf9da9 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java @@ -13,9 +13,8 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.eclipse.jetty.websocket.api.exceptions.WebSocketException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,15 +29,13 @@ public class WebSocketClient { private static final SecureRandom SECURE_RANDOM = new SecureRandom(); private final Session session; - private final RemoteEndpoint remoteEndpoint; private final WebSocketMessageFactory messageFactory; private final Map> pendingRequestMapper; private final Instant created; - public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint, WebSocketMessageFactory messageFactory, + public WebSocketClient(Session session, WebSocketMessageFactory messageFactory, Map> pendingRequestMapper) { this.session = session; - this.remoteEndpoint = remoteEndpoint; this.messageFactory = messageFactory; this.pendingRequestMapper = pendingRequestMapper; this.created = Instant.now(); @@ -56,9 +53,9 @@ public class WebSocketClient { WebSocketMessage requestMessage = messageFactory.createRequest(Optional.of(requestId), verb, path, headers, body); try { - remoteEndpoint.sendBytes(ByteBuffer.wrap(requestMessage.toByteArray()), new WriteCallback() { + session.sendBinary(ByteBuffer.wrap(requestMessage.toByteArray()), new Callback() { @Override - public void writeFailed(Throwable x) { + public void fail(Throwable x) { logger.debug("Write failed", x); pendingRequestMapper.remove(requestId); future.completeExceptionally(x); @@ -86,9 +83,9 @@ public class WebSocketClient { } public void close(final int code, final String message) { - session.close(code, message, new WriteCallback() { + session.close(code, message, new Callback() { @Override - public void writeFailed(final Throwable throwable) { + public void fail(final Throwable throwable) { try { session.disconnect(); } catch (final Exception e) { @@ -108,6 +105,6 @@ public class WebSocketClient { } public SocketAddress getRemoteAddress() { - return session.getRemoteAddress(); + return session.getRemoteSocketAddress(); } } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index 32262a3e7..245fcb724 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -16,7 +16,6 @@ import java.net.URI; import java.nio.ByteBuffer; import java.security.Principal; import java.time.Duration; -import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedList; import java.util.List; @@ -26,12 +25,8 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; -import org.eclipse.jetty.server.Request; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.WebSocketListener; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.eclipse.jetty.websocket.api.exceptions.MessageTooLargeException; import org.glassfish.jersey.internal.MapPropertiesDelegate; import org.glassfish.jersey.server.ApplicationHandler; @@ -51,11 +46,11 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") -public class WebSocketResourceProvider implements WebSocketListener { +public class WebSocketResourceProvider implements Session.Listener.AutoDemanding { /** * A static exception instance passed to outstanding requests (via {@code completeExceptionally} in - * {@link #onWebSocketClose(int, String)} + * {@link #onWebSocketClose} */ public static final IOException CONNECTION_CLOSED_EXCEPTION = new IOException("Connection closed!"); private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProvider.class); @@ -73,7 +68,6 @@ public class WebSocketResourceProvider implements WebSocket private final int localPort; private Session session; - private RemoteEndpoint remoteEndpoint; private WebSocketSessionContext context; private static final Set EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade"); @@ -99,11 +93,10 @@ public class WebSocketResourceProvider implements WebSocket } @Override - public void onWebSocketConnect(Session session) { + public void onWebSocketOpen(Session session) { this.session = session; - this.remoteEndpoint = session.getRemote(); this.context = new WebSocketSessionContext( - new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap)); + new WebSocketClient(session, messageFactory, requestMap)); this.context.setAuthenticated(reusableAuth.orElse(null)); this.session.setIdleTimeout(idleTimeout); @@ -128,9 +121,19 @@ public class WebSocketResourceProvider implements WebSocket } @Override - public void onWebSocketBinary(byte[] payload, int offset, int length) { + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + onWebSocketBinary(payload); + callback.succeed(); + } catch (RuntimeException e) { + callback.fail(e); + throw e; + } + } + + private void onWebSocketBinary(ByteBuffer payload) { + try { + final WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); switch (webSocketMessage.getType()) { case REQUEST_MESSAGE: @@ -150,17 +153,21 @@ public class WebSocketResourceProvider implements WebSocket } @Override - public void onWebSocketClose(int statusCode, String reason) { - if (context != null) { - context.notifyClosed(statusCode, reason); + public void onWebSocketClose(int statusCode, String reason, Callback callback) { + try { + if (context != null) { + context.notifyClosed(statusCode, reason); - for (long requestId : requestMap.keySet()) { - CompletableFuture outstandingRequest = requestMap.remove(requestId); + for (long requestId : requestMap.keySet()) { + CompletableFuture outstandingRequest = requestMap.remove(requestId); - if (outstandingRequest != null) { - outstandingRequest.completeExceptionally(CONNECTION_CLOSED_EXCEPTION); + if (outstandingRequest != null) { + outstandingRequest.completeExceptionally(CONNECTION_CLOSED_EXCEPTION); + } } } + } finally { + callback.succeed(); } } @@ -287,7 +294,7 @@ public class WebSocketResourceProvider implements WebSocket } private void close(Session session, int status, String message) { - session.close(status, message); + session.close(status, message, Callback.NOOP); } private void sendResponse(WebSocketRequestMessage requestMessage, ContainerResponse response, @@ -306,7 +313,7 @@ public class WebSocketResourceProvider implements WebSocket Optional.ofNullable(body)) .toByteArray(); - remoteEndpoint.sendBytes(ByteBuffer.wrap(responseBytes), WriteCallback.NOOP); + session.sendBinary(ByteBuffer.wrap(responseBytes), Callback.NOOP); } } @@ -318,7 +325,7 @@ public class WebSocketResourceProvider implements WebSocket getHeaderList(error.getStringHeaders()), Optional.empty()); - remoteEndpoint.sendBytes(ByteBuffer.wrap(response.toByteArray()), WriteCallback.NOOP); + session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP); } } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index 782bd4783..6ce3ba7f7 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -13,11 +13,9 @@ import java.security.Principal; import java.util.Map; import java.util.Optional; import org.apache.commons.lang3.StringUtils; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; -import org.eclipse.jetty.websocket.server.JettyWebSocketCreator; -import org.eclipse.jetty.websocket.server.JettyWebSocketServlet; -import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; +import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketCreator; import org.glassfish.jersey.CommonProperties; import org.glassfish.jersey.server.ApplicationHandler; import org.slf4j.Logger; @@ -25,23 +23,20 @@ import org.slf4j.LoggerFactory; import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; -import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; import org.whispersystems.websocket.setup.WebSocketEnvironment; -public class WebSocketResourceProviderFactory extends JettyWebSocketServlet implements - JettyWebSocketCreator { +public class WebSocketResourceProviderFactory implements JettyWebSocketCreator { private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class); private final WebSocketEnvironment environment; private final ApplicationHandler jerseyApplicationHandler; - private final WebSocketConfiguration configuration; private final String remoteAddressPropertyName; public WebSocketResourceProviderFactory(WebSocketEnvironment environment, Class principalClass, - WebSocketConfiguration configuration, String remoteAddressPropertyName) { + String remoteAddressPropertyName) { this.environment = environment; environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder()); @@ -55,7 +50,6 @@ public class WebSocketResourceProviderFactory extends Jetty this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey()); - this.configuration = configuration; this.remoteAddressPropertyName = remoteAddressPropertyName; } @@ -91,6 +85,7 @@ public class WebSocketResourceProviderFactory extends Jetty // Authentication may fail for non-incorrect-credential reasons (e.g. we couldn't read from the account database). // If that happens, we don't want to incorrectly tell clients that they provided bad credentials. logger.warn("Authentication failure", e); + try { response.sendError(500, "Failure"); } catch (final IOException ignored) { @@ -99,13 +94,6 @@ public class WebSocketResourceProviderFactory extends Jetty } } - @Override - public void configure(JettyWebSocketServletFactory factory) { - factory.setCreator(this); - factory.setMaxBinaryMessageSize(configuration.getMaxBinaryMessageSize()); - factory.setMaxTextMessageSize(configuration.getMaxTextMessageSize()); - } - private String getRemoteAddress(JettyServerUpgradeRequest request) { final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName); if (StringUtils.isBlank(remoteAddress)) { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java index bf0593dc0..f2f49cc9f 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java @@ -7,8 +7,8 @@ package org.whispersystems.websocket.auth; import java.security.Principal; import java.util.Optional; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; public interface AuthenticatedWebSocketUpgradeFilter { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java index 7836b5b8d..4974a707e 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java @@ -6,7 +6,7 @@ package org.whispersystems.websocket.auth; import java.security.Principal; import java.util.Optional; -import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; public interface WebSocketAuthenticator { @@ -20,5 +20,5 @@ public interface WebSocketAuthenticator { * * @throws InvalidCredentialsException if credentials were provided, but could not be authenticated */ - Optional authenticate(UpgradeRequest request) throws InvalidCredentialsException; + Optional authenticate(JettyServerUpgradeRequest request) throws InvalidCredentialsException; } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/WebSocketMessageFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/WebSocketMessageFactory.java index d24a7e09b..bbaeb75cc 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/WebSocketMessageFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/WebSocketMessageFactory.java @@ -5,22 +5,23 @@ package org.whispersystems.websocket.messages; +import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public interface WebSocketMessageFactory { - public WebSocketMessage parseMessage(byte[] serialized, int offset, int len) + WebSocketMessage parseMessage(ByteBuffer serialized) throws InvalidMessageException; - public WebSocketMessage createRequest(Optional requestId, - String verb, String path, - List headers, - Optional body); + WebSocketMessage createRequest(Optional requestId, + String verb, String path, + List headers, + Optional body); - public WebSocketMessage createResponse(long requestId, int status, String message, - List headers, - Optional body); + WebSocketMessage createResponse(long requestId, int status, String message, + List headers, + Optional body); } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessage.java b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessage.java index 909673363..8d021cc81 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessage.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessage.java @@ -10,14 +10,15 @@ import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketRequestMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import java.nio.ByteBuffer; public class ProtobufWebSocketMessage implements WebSocketMessage { private final SubProtocol.WebSocketMessage message; - ProtobufWebSocketMessage(byte[] buffer, int offset, int length) throws InvalidMessageException { + ProtobufWebSocketMessage(ByteBuffer buffer) throws InvalidMessageException { try { - this.message = SubProtocol.WebSocketMessage.parseFrom(ByteString.copyFrom(buffer, offset, length)); + this.message = SubProtocol.WebSocketMessage.parseFrom(ByteString.copyFrom(buffer)); if (getType() == Type.REQUEST_MESSAGE) { if (!message.getRequest().hasVerb() || !message.getRequest().hasPath()) { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessageFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessageFactory.java index d4f25bd89..33c108771 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessageFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessageFactory.java @@ -9,16 +9,17 @@ import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketMessageFactory; +import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; public class ProtobufWebSocketMessageFactory implements WebSocketMessageFactory { @Override - public WebSocketMessage parseMessage(byte[] serialized, int offset, int len) + public WebSocketMessage parseMessage(ByteBuffer serialized) throws InvalidMessageException { - return new ProtobufWebSocketMessage(serialized, offset, len); + return new ProtobufWebSocketMessage(serialized); } @Override diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java index 31e44e09f..4cc5e98f4 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java @@ -19,17 +19,15 @@ import java.io.IOException; import java.security.Principal; import java.util.Optional; import javax.security.auth.Subject; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; -import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory; import org.glassfish.jersey.server.ResourceConfig; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; +import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; -import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.setup.WebSocketEnvironment; public class WebSocketResourceProviderFactoryTest { @@ -60,8 +58,7 @@ public class WebSocketResourceProviderFactoryTest { when(authenticator.authenticate(eq(request))).thenThrow(new InvalidCredentialsException()); when(environment.jersey()).thenReturn(jerseyEnvironment); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); assertNull(connection); @@ -80,16 +77,15 @@ public class WebSocketResourceProviderFactoryTest { final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); when(request.getHttpServletRequest()).thenReturn(httpServletRequest); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); + REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); assertNotNull(connection); verifyNoMoreInteractions(response); verify(authenticator).authenticate(eq(request)); - ((WebSocketResourceProvider) connection).onWebSocketConnect(mock(Session.class)); + ((WebSocketResourceProvider) connection).onWebSocketOpen(mock(Session.class)); assertNotNull(((WebSocketResourceProvider) connection).getContext().getAuthenticated()); assertEquals(((WebSocketResourceProvider) connection).getContext().getAuthenticated(), account); @@ -103,7 +99,6 @@ public class WebSocketResourceProviderFactoryTest { WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); @@ -112,20 +107,6 @@ public class WebSocketResourceProviderFactoryTest { verify(authenticator).authenticate(eq(request)); } - @Test - void testConfigure() { - JettyWebSocketServletFactory servletFactory = mock(JettyWebSocketServletFactory.class); - when(environment.jersey()).thenReturn(jerseyEnvironment); - - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, - Account.class, - mock(WebSocketConfiguration.class), - REMOTE_ADDRESS_PROPERTY_NAME); - factory.configure(servletFactory); - - verify(servletFactory).setCreator(eq(factory)); - } - @Test void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException { final Account account = new Account(); @@ -137,12 +118,11 @@ public class WebSocketResourceProviderFactoryTest { final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); when(request.getHttpServletRequest()).thenReturn(httpServletRequest); - final AuthenticatedWebSocketUpgradeFilter filter = mock(AuthenticatedWebSocketUpgradeFilter.class); when(environment.getAuthenticatedWebSocketUpgradeFilter()).thenReturn(filter); final WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); + REMOTE_ADDRESS_PROPERTY_NAME); assertNotNull(factory.createWebSocket(request, response)); verify(filter).handleAuthentication(reusableAuth, request, response); diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index 64738f9ca..bbde7a00d 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -47,11 +47,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import org.eclipse.jetty.websocket.api.CloseStatus; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerResponse; @@ -92,11 +90,10 @@ class WebSocketResourceProviderTest { when(session.getUpgradeRequest()).thenReturn(request); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); - verify(session, never()).close(anyInt(), anyString()); + verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(); - verify(session, never()).close(any(CloseStatus.class)); ArgumentCaptor contextArgumentCaptor = ArgumentCaptor.forClass( WebSocketSessionContext.class); @@ -114,11 +111,9 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); ContainerResponse response = mock(ContainerResponse.class); when(response.getStatus()).thenReturn(200); @@ -148,16 +143,15 @@ class WebSocketResourceProviderTest { return CompletableFuture.completedFuture(response); }); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); - verify(session, never()).close(anyInt(), anyString()); + verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(); - verify(session, never()).close(any(CloseStatus.class)); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class); @@ -171,7 +165,7 @@ class WebSocketResourceProviderTest { assertThat(bundledRequest.getPath(false)).isEqualTo("bar"); verify(requestLog).log(eq("127.0.0.1"), eq(bundledRequest), eq(response)); - verify(remoteEndpoint).sendBytes(responseCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom( responseCaptor.getValue().array()); @@ -191,25 +185,22 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); when(applicationHandler.apply(any(ContainerRequest.class), any(OutputStream.class))).thenReturn( CompletableFuture.failedFuture(new IllegalStateException("foo"))); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); - verify(session, never()).close(anyInt(), anyString()); + verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(); - verify(session, never()).close(any(CloseStatus.class)); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class); @@ -223,7 +214,7 @@ class WebSocketResourceProviderTest { ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom( responseCaptor.getValue().array()); @@ -247,22 +238,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -287,22 +276,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/doesntexist", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -327,22 +314,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -367,22 +352,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -406,22 +389,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -446,22 +427,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -486,23 +465,21 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", "/v1/test/some/testparam", List.of("Content-Type: application/json"), Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 1001)))).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -527,23 +504,21 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", "/v1/test/some/testparam", List.of("Content-Type: application/json"), Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 5)))).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -569,22 +544,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/exception/map", List.of("Content-Type: application/json"), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -609,22 +582,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/keepalive", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(requestCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(requestCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketRequestMessage requestMessage = getRequest(requestCaptor); assertThat(requestMessage.getVerb()).isEqualTo("GET"); @@ -634,11 +605,11 @@ class WebSocketResourceProviderTest { byte[] clientResponse = new ProtobufWebSocketMessageFactory().createResponse(requestMessage.getId(), 200, "OK", new LinkedList<>(), Optional.of("my response".getBytes())).toByteArray(); - provider.onWebSocketBinary(clientResponse, 0, clientResponse.length); + provider.onWebSocketBinary(ByteBuffer.wrap(clientResponse), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint, times(2)).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session, times(2)).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);