Add request/response byte counters to the gRPC metrics interceptor

This commit is contained in:
Jon Chambers 2026-03-24 12:25:00 -04:00 committed by Jon Chambers
parent 8cc0948a34
commit d9d6560b47
2 changed files with 52 additions and 7 deletions

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.grpc;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.rpc.ErrorInfo;
@ -53,8 +54,12 @@ public class MetricServerInterceptor implements ServerInterceptor {
@VisibleForTesting
static final String REQUEST_MESSAGE_COUNTER_NAME = MetricsUtil.name(MetricServerInterceptor.class, "requestMessage");
@VisibleForTesting
static final String REQUEST_BYTES_COUNTER_NAME = MetricsUtil.name(MetricServerInterceptor.class, "requestBytes");
@VisibleForTesting
static final String RESPONSE_COUNTER_NAME = MetricsUtil.name(MetricServerInterceptor.class, "responseMessage");
@VisibleForTesting
static final String RESPONSE_BYTES_COUNTER_NAME = MetricsUtil.name(MetricServerInterceptor.class, "responseBytes");
@VisibleForTesting
static final String RPC_COUNTER_NAME = MetricsUtil.name(MetricServerInterceptor.class, "rpc");
@VisibleForTesting
static final String DURATION_TIMER_NAME = MetricsUtil.name(MetricServerInterceptor.class, "processingDuration");
@ -99,12 +104,14 @@ public class MetricServerInterceptor implements ServerInterceptor {
private class MetricServerCall<ReqT, RespT> extends ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT> {
private final Counter responseMessageCounter;
private final Counter responseBytesCounter;
private final Tags tags;
private @Nullable String reason = null;
MetricServerCall(final ServerCall<ReqT, RespT> delegate, final Tags tags) {
super(delegate);
this.responseMessageCounter = meterRegistry.counter(RESPONSE_COUNTER_NAME, tags);
this.responseBytesCounter = meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags);
this.tags = tags;
}
@ -126,6 +133,11 @@ public class MetricServerInterceptor implements ServerInterceptor {
@Override
public void sendMessage(final RespT responseMessage) {
this.responseMessageCounter.increment();
if (responseMessage instanceof GeneratedMessage generatedMessage) {
this.responseBytesCounter.increment(generatedMessage.getSerializedSize());
}
// Extract the annotated reason (if any) from the message
final String messageReason = MetricServerCall.reason(responseMessage);
@ -171,12 +183,14 @@ public class MetricServerInterceptor implements ServerInterceptor {
private class MetricServerCallListener<ReqT> extends ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT> {
private final Counter requestCounter;
private final Counter requestBytesCounter;
private final Timer responseTimer;
private final Timer.Sample sample;
MetricServerCallListener(final ServerCall.Listener<ReqT> delegate, final Tags tags) {
super(delegate);
this.requestCounter = meterRegistry.counter(REQUEST_MESSAGE_COUNTER_NAME, tags);
this.requestBytesCounter = meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags);
this.responseTimer = meterRegistry.timer(DURATION_TIMER_NAME, tags);
this.sample = Timer.start(meterRegistry);
}
@ -184,6 +198,11 @@ public class MetricServerInterceptor implements ServerInterceptor {
@Override
public void onMessage(final ReqT requestMessage) {
this.requestCounter.increment();
if (requestMessage instanceof GeneratedMessage generatedMessage) {
this.requestBytesCounter.increment(generatedMessage.getSerializedSize());
}
super.onMessage(requestMessage);
}

View File

@ -28,10 +28,12 @@ import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Flow;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@ -100,7 +102,8 @@ public class MetricServerInterceptorTest {
@Test
void unary() {
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
client.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("hello")).build());
final EchoRequest request = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("hello")).build();
final EchoResponse response = client.echo(request);
final Tags commonTags = Tags.of(
"platform", "android",
@ -110,9 +113,15 @@ public class MetricServerInterceptorTest {
final Counter requestCount = find(Counter.class, MetricServerInterceptor.REQUEST_MESSAGE_COUNTER_NAME);
assertThat(requestCount.count()).isCloseTo(1.0, offset(0.01));
final Counter requestByteCount = find(Counter.class, MetricServerInterceptor.REQUEST_BYTES_COUNTER_NAME);
assertThat(requestByteCount.count()).isCloseTo(request.getSerializedSize(), offset(0.01));
final Counter responseCount = find(Counter.class, MetricServerInterceptor.RESPONSE_COUNTER_NAME);
assertThat(responseCount.count()).isCloseTo(1.0, offset(0.01));
final Counter responseByteCount = find(Counter.class, MetricServerInterceptor.RESPONSE_BYTES_COUNTER_NAME);
assertThat(responseByteCount.count()).isCloseTo(response.getSerializedSize(), offset(0.01));
final Counter rpcCount = find(Counter.class, MetricServerInterceptor.RPC_COUNTER_NAME);
assertThat(rpcCount.count()).isCloseTo(1.0, offset(0.01));
@ -132,20 +141,37 @@ public class MetricServerInterceptorTest {
void streaming() throws StatusException, InterruptedException {
final EchoServiceGrpc.EchoServiceBlockingV2Stub client = EchoServiceGrpc.newBlockingV2Stub(channel);
final BlockingClientCall<EchoRequest, EchoResponse> echoStream = client.echoStream();
echoStream.write(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("1")).build());
echoStream.write(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("2")).build());
echoStream.read();
echoStream.read();
final List<EchoRequest> requests = IntStream.range(0, 2)
.mapToObj(i -> EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8(String.valueOf(i))).build())
.toList();
for (final EchoRequest request : requests) {
echoStream.write(request);
}
final List<EchoResponse> responses = new ArrayList<>(requests.size());
for (int i = 0; i < requests.size(); i++) {
responses.add(echoStream.read());
}
echoStream.halfClose();
// Make sure we don't check metrics before our close is processed
channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
final Counter requestCount = find(Counter.class, MetricServerInterceptor.REQUEST_MESSAGE_COUNTER_NAME);
assertThat(requestCount.count()).isCloseTo(2.0, offset(0.01));
assertThat(requestCount.count()).isCloseTo(requests.size(), offset(0.01));
final Counter requestByteCount = find(Counter.class, MetricServerInterceptor.REQUEST_BYTES_COUNTER_NAME);
assertThat(requestByteCount.count()).isCloseTo(requests.stream().mapToInt(EchoRequest::getSerializedSize).sum(), offset(0.01));
final Counter responseCount = find(Counter.class, MetricServerInterceptor.RESPONSE_COUNTER_NAME);
assertThat(responseCount.count()).isCloseTo(2.0, offset(0.01));
assertThat(responseCount.count()).isCloseTo(responses.size(), offset(0.01));
final Counter responseByteCount = find(Counter.class, MetricServerInterceptor.RESPONSE_BYTES_COUNTER_NAME);
assertThat(responseByteCount.count()).isCloseTo(responses.stream().mapToInt(EchoResponse::getSerializedSize).sum(), offset(0.01));
final Counter rpcCount = find(Counter.class, MetricServerInterceptor.RPC_COUNTER_NAME);
assertThat(rpcCount.count()).isCloseTo(1.0, offset(0.01));