Fix flaky backpressure test

This commit is contained in:
ravi-signal 2026-05-19 15:59:35 -05:00 committed by GitHub
parent b9a24fedea
commit 771fecd396
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.grpc.netty.shaded.io.netty.channel.ChannelOption;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
@ -54,21 +53,24 @@ import java.io.InputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
private static final String KEYSTORE_PASSWORD = "password";
@ -81,73 +83,65 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
private final NioEventLoopGroup nioEventLoopGroup = new NioEventLoopGroup();
private final DefaultEventLoopGroup localEventLoopGroup = new DefaultEventLoopGroup();
private Channel defaultBackend;
private Channel prefixBackend;
private CompletableFuture<Channel> backendConnection;
private OmnibusH2Server server;
private List<Channel> backendChannelsToShutDown;
private List<OmnibusH2Server> omnibusH2ServersToShutDown;
@BeforeEach
void setUp() throws Exception {
// Start two H2C backend servers that echo a response with their identity
defaultBackend = startH2CServer(true, DEFAULT_BACKEND_IDENTITY);
prefixBackend = startH2CServer(false, PREFIX_BACKEND_IDENTITY);
backendConnection = new CompletableFuture<>();
// self-signed TLS context for the frontend loaded from test keyStore
final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12");
server = new OmnibusH2Server(
SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD),
nioEventLoopGroup,
localEventLoopGroup,
new InetSocketAddress("127.0.0.1", 0),
new OmnibusRouter(
List.of(new OmnibusRouter.OmnibusRoute(PREFIX, prefixBackend.localAddress())),
defaultBackend.localAddress()),
Duration.ofMinutes(1));
server.start();
void setUp() {
backendChannelsToShutDown = new ArrayList<>();
omnibusH2ServersToShutDown = new ArrayList<>();
}
@AfterEach
void tearDown() throws Exception {
server.stop();
defaultBackend.close().sync();
prefixBackend.close().sync();
omnibusH2ServersToShutDown.forEach(OmnibusH2Server::stop);
backendChannelsToShutDown.forEach(c -> c.close().syncUninterruptibly());
localEventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync();
nioEventLoopGroup.shutdownGracefully(1, 1000, TimeUnit.MILLISECONDS).sync();
}
@Test
void defaultBackend() {
final String response = sendRequestThroughOmnibus("/a/different/path");
@ParameterizedTest
@ValueSource(booleans = {true, false})
void defaultBackend(final boolean localChannel) throws Exception {
final OmnibusH2Server server = startOmnibusServer(
Map.of(PREFIX, startBackendServer(localChannel, PREFIX_BACKEND_IDENTITY)),
startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY));
final String response = sendRequestThroughOmnibus(connectToOmnibus(server), "/a/different/path");
assertEquals(DEFAULT_BACKEND_IDENTITY, response);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void forwardedForHeader(final boolean usePpv2) {
void forwardedForHeader(final boolean usePpv2) throws Exception {
final String expectedSource = usePpv2 ? "127.0.0.123" : "127.0.0.1";
final HAProxyMessage proxyMessage = usePpv2
? new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, expectedSource, "127.0.0.2", 1234, 5678)
: null;
final Channel h2Connection = connectToOmnibus(null, proxyMessage);
final OmnibusH2Server server = startOmnibusServer(startBackendServer(true, DEFAULT_BACKEND_IDENTITY));
final Channel h2Connection = connectToOmnibus(server, proxyMessage);
final String xForwardedFor = sendRequestThroughOmnibus(h2Connection, "/forwarded-for");
assertEquals(expectedSource, xForwardedFor);
}
@ParameterizedTest
@ValueSource(strings = {"/v1/prefix", "/v1/prefix/", "/v1/prefix/other"})
void prefixBackend(final String path) {
final String response = sendRequestThroughOmnibus(path);
@CartesianTest
void prefixBackend(
@CartesianTest.Values(booleans = {true, false}) final boolean localChannel,
@CartesianTest.Values(strings = {"/v1/prefix", "/v1/prefix/", "/v1/prefix/other"}) final String path) throws Exception {
final OmnibusH2Server server = startOmnibusServer(
Map.of(PREFIX, startBackendServer(localChannel, PREFIX_BACKEND_IDENTITY)),
startBackendServer(true, DEFAULT_BACKEND_IDENTITY));
final String response = sendRequestThroughOmnibus(connectToOmnibus(server), path);
assertEquals(PREFIX_BACKEND_IDENTITY, response);
}
@Test
void multipleStreamsOnSameConnection() {
final Channel h2Connection = connectToOmnibus();
@CartesianTest
void multipleStreamsOnSameConnection(
@CartesianTest.Values(booleans = {true, false}) final boolean defaultLocalChannel,
@CartesianTest.Values(booleans = {true, false}) final boolean prefixLocalChannel) throws Exception {
final OmnibusH2Server server = startOmnibusServer(
Map.of(PREFIX, startBackendServer(prefixLocalChannel, PREFIX_BACKEND_IDENTITY)),
startBackendServer(defaultLocalChannel, DEFAULT_BACKEND_IDENTITY));
final Channel h2Connection = connectToOmnibus(server);
final int numStreams = 10;
// Create concurrent streams to both backends on the same connection simultaneously
@ -168,12 +162,16 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
}
}
@Test
void backendDownStreamReset() {
// Kill the default backend so connection attempts from the omnibus fail
defaultBackend.close().syncUninterruptibly();
@ParameterizedTest
@ValueSource(booleans = {true, false})
void backendDownStreamReset(final boolean localChannel) throws Exception {
final Channel backend = startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY);
final OmnibusH2Server server = startOmnibusServer(backend);
final Channel h2Connection = connectToOmnibus();
// Kill the default backend so connection attempts from the omnibus fail
backend.close().syncUninterruptibly();
final Channel h2Connection = connectToOmnibus(server);
final CompletableFuture<Http2HeadersFrame> headersFuture = new CompletableFuture<>();
final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection)
.handler(new HeadersCollectorHandler(headersFuture))
@ -199,10 +197,13 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
h2Connection.close().syncUninterruptibly();
}
@ParameterizedTest
@ValueSource(strings = {"/goaway", "/reset"})
void backendCloseClosesClientStream(final String path) {
final Channel h2Connection = connectToOmnibus();
@CartesianTest
void backendCloseClosesClientStream(
@CartesianTest.Values(booleans = {true, false}) final boolean localChannel,
@CartesianTest.Values(strings = {"/goaway", "/reset"}) final String path) throws Exception {
final OmnibusH2Server server = startOmnibusServer(startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY));
final Channel h2Connection = connectToOmnibus(server);
final CompletableFuture<Http2ResetFrame> resetFuture = new CompletableFuture<>();
final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection)
.handler(new RstCollectorHandler(resetFuture))
@ -223,9 +224,11 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
assertTrue(h2Connection.isActive());
}
@Test
void queuedDataFrames() throws Exception {
final Channel h2Connection = connectToOmnibus();
@ParameterizedTest
@ValueSource(booleans = {true, false})
void queuedDataFrames(final boolean localChannel) throws Exception {
final OmnibusH2Server server = startOmnibusServer(startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY));
final Channel h2Connection = connectToOmnibus(server);
final CompletableFuture<String> responseFuture = new CompletableFuture<>();
final Http2StreamChannel stream = new Http2StreamChannelBootstrap(h2Connection)
.handler(new ResponseCollectorHandler(responseFuture))
@ -260,9 +263,13 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
h2Connection.close().syncUninterruptibly();
}
@Test
void clientDisconnectClosesBackendConnections() throws Exception {
final Channel h2Connection = connectToOmnibus();
@ParameterizedTest
@ValueSource(booleans = {true, false})
void clientDisconnectClosesBackendConnections(final boolean localChannel) throws Exception {
final CompletableFuture<Channel> backendH2Connection = new CompletableFuture<>();
final Channel backendServerChannel = startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY, backendH2Connection::complete, _ -> {});
final OmnibusH2Server server = startOmnibusServer(backendServerChannel);
final Channel h2Connection = connectToOmnibus(server);
final Http2StreamChannelBootstrap streamBootstrap = new Http2StreamChannelBootstrap(h2Connection);
final Http2StreamChannel stream = streamBootstrap
@ -278,20 +285,27 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
.scheme("https")
.authority("localhost");
stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, false)).syncUninterruptibly();
final Channel backendServerChannel = backendConnection.join();
assertFalse(
backendServerChannel.closeFuture().awaitUninterruptibly(10, TimeUnit.MILLISECONDS),
backendH2Connection.join().closeFuture().awaitUninterruptibly(10, TimeUnit.MILLISECONDS),
"Channel should be open");
// All backend connections the omnibus opened on behalf of this client should close if we disconnect the client
h2Connection.close().syncUninterruptibly();
assertTrue(backendServerChannel.closeFuture().await(5, TimeUnit.SECONDS));
assertTrue(backendH2Connection.join().closeFuture().await(5, TimeUnit.SECONDS));
}
@Test
void backpressure() throws ExecutionException, InterruptedException, TimeoutException {
final Channel h2Connection = connectToOmnibus();
@ParameterizedTest
@ValueSource(booleans = {true, false})
void backpressure(final boolean localChannel) throws Exception {
final AtomicReference<Channel> backendStreamChannel = new AtomicReference<>();
final Channel backendServer = startBackendServer(localChannel, "backpressure", _ -> {
}, ch -> {
ch.config().setAutoRead(false);
backendStreamChannel.set(ch);
});
final OmnibusH2Server omnibusH2Server = startOmnibusServer(backendServer);
final Channel h2Connection = connectToOmnibus(omnibusH2Server, null);
// We'll take the client channel becoming unwritable as backpressure signal
final AtomicBoolean isWritable = new AtomicBoolean(true);
@ -321,10 +335,7 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
final Http2Headers headers = new DefaultHttp2Headers().method("POST").path("/test");
stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers, false)).syncUninterruptibly();
// Make the backend H2 stream processor 'slow' by disabling auto-read on it
final Channel backendServerChannel = backendConnection.join();
backendServerChannel.config().setAutoRead(false);
final long startNanos = System.nanoTime();
final byte[] chunk = new byte[16384];
do {
// Write data until our own client hits the high watermark
@ -341,11 +352,16 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
// more. Eventually all intermediate channels should flush and we should be stuck on the backend channel which
// will never make progress (because auto-read is disabled)
Thread.sleep(100);
assertTrue(Duration.ofNanos(System.nanoTime() - startNanos).compareTo(Duration.ofSeconds(10)) < 0,
"Failed to persistently introduce backpressure after 5 seconds. client bytesBeforeUnwritable: "
+ stream.bytesBeforeUnwritable() +
" backend bytesBeforeUnwritable: " + backendStreamChannel.get().bytesBeforeUnwritable());
} while (isWritable.get());
stream.writeAndFlush(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(chunk), true));
// Now re-enable reads on the backend, which should eventually unblock our writes
backendConnection.resultNow().config().setAutoRead(true);
backendStreamChannel.get().config().setAutoRead(true);
// Now we should eventually be able to send the last (endStream=true) write and get a response
assertEquals("200", response.get(5, TimeUnit.SECONDS).headers().status().toString());
assertTrue(isWritable.get());
@ -353,21 +369,12 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
h2Connection.close().syncUninterruptibly();
}
@Test
void idleTest() throws Exception {
final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12");
@ParameterizedTest
@ValueSource(booleans = {true, false})
void idleTest(final boolean localChannel) throws Exception {
final Duration timeout = Duration.ofMillis(500);
final OmnibusH2Server timeoutServer = new OmnibusH2Server(
SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD),
nioEventLoopGroup,
localEventLoopGroup,
new InetSocketAddress("127.0.0.1", 0),
new OmnibusRouter(
List.of(new OmnibusRouter.OmnibusRoute(PREFIX, prefixBackend.localAddress())),
defaultBackend.localAddress()),
timeout);
timeoutServer.start();
final OmnibusH2Server timeoutServer =
startOmnibusServer(Collections.emptyMap(), startBackendServer(localChannel, DEFAULT_BACKEND_IDENTITY), timeout);
final Channel channel = connectToOmnibus(timeoutServer, null);
// Send a request to make sure idleTimeouts work even after a stream / backend connection has been established
@ -379,41 +386,76 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
timeoutServer.stop();
}
private Channel startH2CServer(final boolean local, final String identity) throws InterruptedException {
final EventLoopGroup eventLoopGroup = local ? localEventLoopGroup : nioEventLoopGroup;
/// Start an OmnibusH2Server. The returned server and provided backends will be torn down in [#tearDown()]
///
/// @param routes A map of prefixes and the corresponding backend server channels the omnibus will target
/// @param defaultBackend The target backend if no prefix routes match the request path
/// @param timeout The omnibus idle timeout
private OmnibusH2Server startOmnibusServer(final Map<String, Channel> routes, final Channel defaultBackend, final Duration timeout) throws Exception {
// self-signed TLS context for the frontend loaded from test keyStore
final InputStream keyStore = OmnibusH2ServerTest.class.getResourceAsStream("omnibus-h2-server-test-keystore.p12");
backendChannelsToShutDown.addAll(routes.values());
backendChannelsToShutDown.add(defaultBackend);
final OmnibusH2Server server = new OmnibusH2Server(
SniMapper.buildSniMapping(keyStore, KEYSTORE_PASSWORD),
nioEventLoopGroup,
localEventLoopGroup,
new InetSocketAddress("127.0.0.1", 0),
new OmnibusRouter(
routes.entrySet().stream().map(entry -> new OmnibusRouter.OmnibusRoute(entry.getKey(), entry.getValue().localAddress())).toList(),
defaultBackend.localAddress()),
timeout);
server.start();
omnibusH2ServersToShutDown.add(server);
return server;
}
private OmnibusH2Server startOmnibusServer(final Channel defaultBackend) throws Exception {
return startOmnibusServer(Collections.emptyMap(), defaultBackend, Duration.ofMinutes(1));
}
private OmnibusH2Server startOmnibusServer(final Map<String, Channel> routes, final Channel defaultBackend) throws Exception {
return startOmnibusServer(routes, defaultBackend, Duration.ofMinutes(1));
}
/// Start a h2c server that can be used as a target of the omnibus
///
/// @param localChannel whether the omnibus should target this backend via a LocalChannel or an NioChannel
/// @param identity how the backend should respond to identity requests
/// @param h2ChannelInit a Consumer that will be called every time a new HTTP/2 connection is made to this server
/// @param h2StreamInit a Consumer that will be called every time a new HTTP/2 stream is created on this server
private Channel startBackendServer(final boolean localChannel, final String identity, Consumer<Channel> h2ChannelInit, Consumer<Channel> h2StreamInit) {
final EventLoopGroup eventLoopGroup = localChannel ? localEventLoopGroup : nioEventLoopGroup;
return new ServerBootstrap()
.group(eventLoopGroup, eventLoopGroup)
.channel(local ? LocalServerChannel.class : NioServerSocketChannel.class)
// Limit size of kernel TCP buffers to make it easier to hit backpressure in tests
.option(NioChannelOption.SO_RCVBUF, 8192)
.option(NioChannelOption.SO_SNDBUF, 8192)
.channel(localChannel ? LocalServerChannel.class : NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel ch) {
backendConnection.complete(ch);
h2ChannelInit.accept(ch);
ch.pipeline().addLast(Http2FrameCodecBuilder.forServer().build());
ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
@Override
protected void initChannel(final Http2StreamChannel ch) {
h2StreamInit.accept(ch);
ch.pipeline().addLast(new TestHandler(identity));
}
}));
}
})
.bind(local ? new LocalAddress(identity) : new InetSocketAddress("127.0.0.1", 0))
.sync()
.bind(localChannel ? new LocalAddress(identity) : new InetSocketAddress("127.0.0.1", 0))
.syncUninterruptibly()
.channel();
}
private Channel connectToOmnibus() {
return connectToOmnibus(null, null);
private Channel startBackendServer(final boolean localChannel, final String identity) {
return startBackendServer(localChannel, identity, _ -> {}, _ -> {});
}
/// Makes an H2 connection to the omnibus at [this#server] on which new H2 streams can be opened
private Channel connectToOmnibus(@Nullable OmnibusH2Server server, @Nullable final HAProxyMessage proxyHeader) {
if (server == null) {
server = this.server;
}
private Channel connectToOmnibus(final OmnibusH2Server server, @Nullable final HAProxyMessage proxyHeader) {
final SslContext clientSsl;
try {
clientSsl = SslContextBuilder.forClient()
@ -455,12 +497,8 @@ class OmnibusH2ServerTest extends AbstractLeakDetectionTest {
return ch;
}
/// Sends an H2 request to [this#server] and returns the H2 response body
private String sendRequestThroughOmnibus(final String path) {
final Channel h2Connection = connectToOmnibus();
final String result = sendRequestThroughOmnibus(h2Connection, path);
h2Connection.close().syncUninterruptibly();
return result;
private Channel connectToOmnibus(final OmnibusH2Server server) {
return connectToOmnibus(server, null);
}
private String sendRequestThroughOmnibus(final Channel h2Connection, final String path) {