diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2ServerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2ServerTest.java index 660ff428c..172b0df8f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2ServerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OmnibusH2ServerTest.java @@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import io.grpc.netty.shaded.io.netty.channel.ChannelOption; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; @@ -54,21 +53,24 @@ import java.io.InputStream; import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.stream.IntStream; import javax.annotation.Nullable; import javax.net.ssl.SSLException; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.junitpioneer.jupiter.cartesian.CartesianTest; class OmnibusH2ServerTest extends AbstractLeakDetectionTest { private static final String KEYSTORE_PASSWORD = "password"; @@ -81,73 +83,65 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { private final NioEventLoopGroup nioEventLoopGroup = new NioEventLoopGroup(); private final DefaultEventLoopGroup localEventLoopGroup = new DefaultEventLoopGroup(); - private Channel defaultBackend; - private Channel prefixBackend; - private CompletableFuture backendConnection; - - private OmnibusH2Server server; + private List backendChannelsToShutDown; + private List omnibusH2ServersToShutDown; @BeforeEach - void setUp() throws Exception { - // Start two H2C backend servers that echo a response with their identity - defaultBackend = startH2CServer(true, DEFAULT_BACKEND_IDENTITY); - prefixBackend = startH2CServer(false, PREFIX_BACKEND_IDENTITY); - backendConnection = new CompletableFuture<>(); - - // self-signed TLS context for the frontend loaded from test keyStore - final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12"); - - server = new OmnibusH2Server( - SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD), - nioEventLoopGroup, - localEventLoopGroup, - new InetSocketAddress("127.0.0.1", 0), - new OmnibusRouter( - List.of(new OmnibusRouter.OmnibusRoute(PREFIX, prefixBackend.localAddress())), - defaultBackend.localAddress()), - Duration.ofMinutes(1)); - - server.start(); + void setUp() { + backendChannelsToShutDown = new ArrayList<>(); + omnibusH2ServersToShutDown = new ArrayList<>(); } @AfterEach void tearDown() throws Exception { - server.stop(); - defaultBackend.close().sync(); - prefixBackend.close().sync(); + omnibusH2ServersToShutDown.forEach(OmnibusH2Server::stop); + backendChannelsToShutDown.forEach(c -> c.close().syncUninterruptibly()); localEventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync(); nioEventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync(); } - - @Test - void defaultBackend() { - final String response = sendRequestThroughOmnibus("/a/different/path"); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void defaultBackend(final boolean localChannel) throws Exception { + final OmnibusH2Server server = startOmnibusServer( + Map.of(PREFIX, startBackendServer(localChannel, PREFIX_BACKEND_IDENTITY)), + startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY)); + final String response = sendRequestThroughOmnibus(connectToOmnibus(server), "/a/different/path"); assertEquals(DEFAULT_BACKEND_IDENTITY, response); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void forwardedForHeader(final boolean usePpv2) { + void forwardedForHeader(final boolean usePpv2) throws Exception { final String expectedSource = usePpv2 ? "127.0.0.123" : "127.0.0.1"; final HAProxyMessage proxyMessage = usePpv2 ? new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, expectedSource, "127.0.0.2", 1234, 5678) : null; - final Channel h2Connection = connectToOmnibus(null, proxyMessage); + final OmnibusH2Server server = startOmnibusServer(startBackendServer(true, DEFAULT_BACKEND_IDENTITY)); + final Channel h2Connection = connectToOmnibus(server, proxyMessage); final String xForwardedFor = sendRequestThroughOmnibus(h2Connection, "/forwarded-for"); assertEquals(expectedSource, xForwardedFor); } - @ParameterizedTest - @ValueSource(strings = {"/v1/prefix", "/v1/prefix/", "/v1/prefix/other"}) - void prefixBackend(final String path) { - final String response = sendRequestThroughOmnibus(path); + @CartesianTest + void prefixBackend( + @CartesianTest.Values(booleans = {true, false}) final boolean localChannel, + @CartesianTest.Values(strings = {"/v1/prefix", "/v1/prefix/", "/v1/prefix/other"}) final String path) throws Exception { + final OmnibusH2Server server = startOmnibusServer( + Map.of(PREFIX, startBackendServer(localChannel, PREFIX_BACKEND_IDENTITY)), + startBackendServer(true, DEFAULT_BACKEND_IDENTITY)); + final String response = sendRequestThroughOmnibus(connectToOmnibus(server), path); assertEquals(PREFIX_BACKEND_IDENTITY, response); } - @Test - void multipleStreamsOnSameConnection() { - final Channel h2Connection = connectToOmnibus(); + @CartesianTest + void multipleStreamsOnSameConnection( + @CartesianTest.Values(booleans = {true, false}) final boolean defaultLocalChannel, + @CartesianTest.Values(booleans = {true, false}) final boolean prefixLocalChannel) throws Exception { + final OmnibusH2Server server = startOmnibusServer( + Map.of(PREFIX, startBackendServer(prefixLocalChannel, PREFIX_BACKEND_IDENTITY)), + startBackendServer(defaultLocalChannel, DEFAULT_BACKEND_IDENTITY)); + final Channel h2Connection = connectToOmnibus(server); final int numStreams = 10; // Create concurrent streams to both backends on the same connection simultaneously @@ -168,12 +162,16 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { } } - @Test - void backendDownStreamReset() { - // Kill the default backend so connection attempts from the omnibus fail - defaultBackend.close().syncUninterruptibly(); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void backendDownStreamReset(final boolean localChannel) throws Exception { + final Channel backend = startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY); + final OmnibusH2Server server = startOmnibusServer(backend); - final Channel h2Connection = connectToOmnibus(); + // Kill the default backend so connection attempts from the omnibus fail + backend.close().syncUninterruptibly(); + + final Channel h2Connection = connectToOmnibus(server); final CompletableFuture headersFuture = new CompletableFuture<>(); final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection) .handler(new HeadersCollectorHandler(headersFuture)) @@ -199,10 +197,13 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { h2Connection.close().syncUninterruptibly(); } - @ParameterizedTest - @ValueSource(strings = {"/goaway", "/reset"}) - void backendCloseClosesClientStream(final String path) { - final Channel h2Connection = connectToOmnibus(); + @CartesianTest + void backendCloseClosesClientStream( + @CartesianTest.Values(booleans = {true, false}) final boolean localChannel, + @CartesianTest.Values(strings = {"/goaway", "/reset"}) final String path) throws Exception { + final OmnibusH2Server server = startOmnibusServer(startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY)); + final Channel h2Connection = connectToOmnibus(server); + final CompletableFuture resetFuture = new CompletableFuture<>(); final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection) .handler(new RstCollectorHandler(resetFuture)) @@ -223,9 +224,11 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { assertTrue(h2Connection.isActive()); } - @Test - void queuedDataFrames() throws Exception { - final Channel h2Connection = connectToOmnibus(); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void queuedDataFrames(final boolean localChannel) throws Exception { + final OmnibusH2Server server = startOmnibusServer(startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY)); + final Channel h2Connection = connectToOmnibus(server); final CompletableFuture responseFuture = new CompletableFuture<>(); final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection) .handler(new ResponseCollectorHandler(responseFuture)) @@ -260,9 +263,13 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { h2Connection.close().syncUninterruptibly(); } - @Test - void clientDisconnectClosesBackendConnections() throws Exception { - final Channel h2Connection = connectToOmnibus(); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void clientDisconnectClosesBackendConnections(final boolean localChannel) throws Exception { + final CompletableFuture backendH2Connection = new CompletableFuture<>(); + final Channel backendServerChannel = startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY, backendH2Connection::complete, _ -> {}); + final OmnibusH2Server server = startOmnibusServer(backendServerChannel); + final Channel h2Connection = connectToOmnibus(server); final Http2StreamChannelBootstrap streamBootstrap = new Http2StreamChannelBootstrap(h2Connection); final Http2StreamChannel stream = streamBootstrap @@ -278,20 +285,27 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { .scheme("https") .authority("localhost"); stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, false)).syncUninterruptibly(); - final Channel backendServerChannel = backendConnection.join(); + assertFalse( - backendServerChannel.closeFuture().awaitUninterruptibly(10, TimeUnit.MILLISECONDS), + backendH2Connection.join().closeFuture().awaitUninterruptibly(10, TimeUnit.MILLISECONDS), "Channel should be open"); // All backend connections the omnibus opened on behalf of this client should close if we disconnect the client h2Connection.close().syncUninterruptibly(); - assertTrue(backendServerChannel.closeFuture().await(5, TimeUnit.SECONDS)); + assertTrue(backendH2Connection.join().closeFuture().await(5, TimeUnit.SECONDS)); } - - @Test - void backpressure() throws ExecutionException, InterruptedException, TimeoutException { - final Channel h2Connection = connectToOmnibus(); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void backpressure(final boolean localChannel) throws Exception { + final AtomicReference backendStreamChannel = new AtomicReference<>(); + final Channel backendServer = startBackendServer(localChannel, "backpressure", _ -> { + }, ch -> { + ch.config().setAutoRead(false); + backendStreamChannel.set(ch); + }); + final OmnibusH2Server omnibusH2Server = startOmnibusServer(backendServer); + final Channel h2Connection = connectToOmnibus(omnibusH2Server, null); // We'll take the client channel becoming unwritable as backpressure signal final AtomicBoolean isWritable = new AtomicBoolean(true); @@ -321,10 +335,7 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { final Http2Headers headers = new DefaultHttp2Headers().method("POST").path("/test"); stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, false)).syncUninterruptibly(); - // Make the backend H2 stream processor 'slow' by disabling auto-read on it - final Channel backendServerChannel = backendConnection.join(); - backendServerChannel.config().setAutoRead(false); - + final long startNanos = System.nanoTime(); final byte[] chunk = new byte[16384]; do { // Write data until our own client hits the high watermark @@ -341,11 +352,16 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { // more. Eventually all intermediate channels should flush and we should be stuck on the backend channel which // will never make progress (because auto-read is disabled) Thread.sleep(100); + + assertTrue(Duration.ofNanos(System.nanoTime() - startNanos).compareTo(Duration.ofSeconds(10)) < 0, + "Failed to persistently introduce backpressure after 5 seconds. client bytesBeforeUnwritable: " + + stream.bytesBeforeUnwritable() + + " backend bytesBeforeUnwritable: " + backendStreamChannel.get().bytesBeforeUnwritable()); } while (isWritable.get()); stream.writeAndFlush(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(chunk), true)); // Now re-enable reads on the backend, which should eventually unblock our writes - backendConnection.resultNow().config().setAutoRead(true); + backendStreamChannel.get().config().setAutoRead(true); // Now we should eventually be able to send the last (endStream=true) write and get a response assertEquals("200", response.get(5, TimeUnit.SECONDS).headers().status().toString()); assertTrue(isWritable.get()); @@ -353,21 +369,12 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { h2Connection.close().syncUninterruptibly(); } - @Test - void idleTest() throws Exception { - final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12"); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void idleTest(final boolean localChannel) throws Exception { final Duration timeout = Duration.ofMillis(500); - final OmnibusH2Server timeoutServer = new OmnibusH2Server( - SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD), - nioEventLoopGroup, - localEventLoopGroup, - new InetSocketAddress("127.0.0.1", 0), - new OmnibusRouter( - List.of(new OmnibusRouter.OmnibusRoute(PREFIX, prefixBackend.localAddress())), - defaultBackend.localAddress()), - timeout); - timeoutServer.start(); - + final OmnibusH2Server timeoutServer = + startOmnibusServer(Collections.emptyMap(), startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY), timeout); final Channel channel = connectToOmnibus(timeoutServer, null); // Send a request to make sure idleTimeouts work even after a stream / backend connection has been established @@ -379,41 +386,76 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { timeoutServer.stop(); } - private Channel startH2CServer(final boolean local, final String identity) throws InterruptedException { - final EventLoopGroup eventLoopGroup = local ? localEventLoopGroup : nioEventLoopGroup; + /// Start an OmnibusH2Server. The returned server and provided backends will be torn down in [#tearDown()] + /// + /// @param routes A map of prefixes and the corresponding backend server channels the omnibus will target + /// @param defaultBackend The target backend if no prefix routes match the request path + /// @param timeout The omnibus idle timeout + private OmnibusH2Server startOmnibusServer(final Map routes, final Channel defaultBackend, final Duration timeout) throws Exception { + // self-signed TLS context for the frontend loaded from test keyStore + final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12"); + + backendChannelsToShutDown.addAll(routes.values()); + backendChannelsToShutDown.add(defaultBackend); + + final OmnibusH2Server server = new OmnibusH2Server( + SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD), + nioEventLoopGroup, + localEventLoopGroup, + new InetSocketAddress("127.0.0.1", 0), + new OmnibusRouter( + routes.entrySet().stream().map(entry -> new OmnibusRouter.OmnibusRoute(entry.getKey(), entry.getValue().localAddress())).toList(), + defaultBackend.localAddress()), + timeout); + server.start(); + omnibusH2ServersToShutDown.add(server); + return server; + } + + private OmnibusH2Server startOmnibusServer(final Channel defaultBackend) throws Exception { + return startOmnibusServer(Collections.emptyMap(), defaultBackend, Duration.ofMinutes(1)); + } + + private OmnibusH2Server startOmnibusServer(final Map routes, final Channel defaultBackend) throws Exception { + return startOmnibusServer(routes, defaultBackend, Duration.ofMinutes(1)); + } + + /// Start a h2c server that can be used as a target of the omnibus + /// + /// @param localChannel whether the omnibus should target this backend via a LocalChannel or an NioChannel + /// @param identity how the backend should respond to identity requests + /// @param h2ChannelInit a Consumer that will be called every time a new HTTP/2 connection is made to this server + /// @param h2StreamInit a Consumer that will be called every time a new HTTP/2 stream is created on this server + private Channel startBackendServer(final boolean localChannel, final String identity, Consumer h2ChannelInit, Consumer h2StreamInit) { + final EventLoopGroup eventLoopGroup = localChannel ? localEventLoopGroup : nioEventLoopGroup; return new ServerBootstrap() .group(eventLoopGroup, eventLoopGroup) - .channel(local ? LocalServerChannel.class : NioServerSocketChannel.class) - // Limit size of kernel TCP buffers to make it easier to hit backpressure in tests - .option(NioChannelOption.SO_RCVBUF, 8192) - .option(NioChannelOption.SO_SNDBUF, 8192) + .channel(localChannel ? LocalServerChannel.class : NioServerSocketChannel.class) .childHandler(new ChannelInitializer<>() { @Override protected void initChannel(final Channel ch) { - backendConnection.complete(ch); + h2ChannelInit.accept(ch); ch.pipeline().addLast(Http2FrameCodecBuilder.forServer().build()); ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInitializer() { @Override protected void initChannel(final Http2StreamChannel ch) { + h2StreamInit.accept(ch); ch.pipeline().addLast(new TestHandler(identity)); } })); } }) - .bind(local ? new LocalAddress(identity) : new InetSocketAddress("127.0.0.1", 0)) - .sync() + .bind(localChannel ? new LocalAddress(identity) : new InetSocketAddress("127.0.0.1", 0)) + .syncUninterruptibly() .channel(); } - private Channel connectToOmnibus() { - return connectToOmnibus(null, null); + private Channel startBackendServer(final boolean localChannel, final String identity) { + return startBackendServer(localChannel, identity, _ -> {}, _ -> {}); } /// Makes an H2 connection to the omnibus at [this#server] on which new H2 streams can be opened - private Channel connectToOmnibus(@Nullable OmnibusH2Server server, @Nullable final HAProxyMessage proxyHeader) { - if (server == null) { - server = this.server; - } + private Channel connectToOmnibus(final OmnibusH2Server server, @Nullable final HAProxyMessage proxyHeader) { final SslContext clientSsl; try { clientSsl = SslContextBuilder.forClient() @@ -455,12 +497,8 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest { return ch; } - /// Sends an H2 request to [this#server] and returns the H2 response body - private String sendRequestThroughOmnibus(final String path) { - final Channel h2Connection = connectToOmnibus(); - final String result = sendRequestThroughOmnibus(h2Connection, path); - h2Connection.close().syncUninterruptibly(); - return result; + private Channel connectToOmnibus(final OmnibusH2Server server) { + return connectToOmnibus(server, null); } private String sendRequestThroughOmnibus(final Channel h2Connection, final String path) {