Update messages integration test to use websocket

This commit is contained in:
Ravi Khadiwala 2026-06-05 12:40:43 -05:00 committed by ravi-signal
parent d69027ce5c
commit d9c39cc12b
4 changed files with 292 additions and 28 deletions

View File

@ -10,6 +10,10 @@
<modelVersion>4.0.0</modelVersion>
<artifactId>integration-tests</artifactId>
<properties>
<jetty.http2-client.version>12.1.5</jetty.http2-client.version>
</properties>
<dependencies>
<dependency>
<groupId>org.whispersystems.textsecure</groupId>
@ -21,6 +25,22 @@
<groupId>software.amazon.awssdk</groupId>
<artifactId>dynamodb</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>jetty-websocket-jetty-api</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>jetty-websocket-jetty-client</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.http2</groupId>
<artifactId>jetty-http2-client-transport</artifactId>
<version>${jetty.http2-client.version}</version>
</dependency>
</dependencies>
<build>

View File

@ -20,18 +20,28 @@ import java.net.URL;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.tuple.Pair;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.http2.client.HTTP2Client;
import org.eclipse.jetty.http2.client.transport.HttpClientTransportOverHTTP2;
import org.eclipse.jetty.io.ClientConnector;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.signal.integration.config.Config;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
@ -49,6 +59,7 @@ import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.CertificateUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.HttpUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -65,6 +76,8 @@ public final class Operations {
private static final FaultTolerantHttpClient CLIENT = buildClient();
private static final WebSocketClient WEB_SOCKET_CLIENT = buildWebSocketClient();
private Operations() {
// utility class
@ -218,14 +231,10 @@ public final class Operations {
}
private static <R> RequestBuilder withJsonBody(final String endpoint, final String method, final R input) {
try {
final byte[] body = SystemMapper.jsonMapper().writeValueAsBytes(input);
return new RequestBuilder(HttpRequest.newBuilder()
.header(HttpHeaders.CONTENT_TYPE, "application/json")
.method(method, HttpRequest.BodyPublishers.ofByteArray(body)), endpoint);
} catch (final JsonProcessingException e) {
throw new RuntimeException(e);
}
final byte[] body = encodeJsonBody(input);
return new RequestBuilder(HttpRequest.newBuilder()
.header(HttpHeaders.CONTENT_TYPE, "application/json")
.method(method, HttpRequest.BodyPublishers.ofByteArray(body)), endpoint);
}
public RequestBuilder authorized(final TestUser user) {
@ -305,6 +314,7 @@ public final class Operations {
})
.join();
}
}
private static FaultTolerantHttpClient buildClient() {
@ -317,6 +327,53 @@ public final class Operations {
}
}
private static WebSocketClient buildWebSocketClient() {
try {
final KeyStore trustStore = CertificateUtil.buildKeyStoreForPem(CONFIG.rootCert());
final SslContextFactory.Client sslContextFactory = new SslContextFactory.Client();
sslContextFactory.setTrustStore(trustStore);
final ClientConnector connector = new ClientConnector();
connector.setSslContextFactory(sslContextFactory);
final HTTP2Client http2Client = new HTTP2Client(connector);
final HttpClient httpClient = new HttpClient(new HttpClientTransportOverHTTP2(http2Client));
final WebSocketClient wsClient = new WebSocketClient(httpClient);
wsClient.start();
return wsClient;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public static WebsocketClientSession authenticatedWebsocket(final TestUser user, final byte deviceId) throws IOException {
final String username = "%s.%d".formatted(user.aciUuid().toString(), deviceId);
return connect("/v1/websocket/", Map.of(HttpHeaders.AUTHORIZATION, HeaderUtils.basicAuthHeader(username, user.accountPassword())));
}
public static WebsocketClientSession anonymousWebsocket() throws IOException {
return connect("/v1/websocket/", Collections.emptyMap());
}
private static WebsocketClientSession connect(
final String path,
final Map<String, String> headers) throws IOException {
final URI uri = URI.create("wss://grpc." + CONFIG.domain() + path);
final ClientUpgradeRequest request = new ClientUpgradeRequest(uri);
headers.forEach(request::setHeader);
final WebsocketClientSession listener = new WebsocketClientSession();
try {
WEB_SOCKET_CLIENT.connect(listener, request).get(5, TimeUnit.SECONDS);
} catch (Exception e) {
throw new IOException(e);
}
logger.info("Successfully connected to websocket on {}", uri);
return listener;
}
private static Config loadConfigFromClasspath(final String filename) {
try {
final URL configFileUrl = Resources.getResource(filename);
@ -345,4 +402,12 @@ public final class Operations {
final byte[] signature = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new KEMSignedPreKey(id, pubKey, signature);
}
public static <R> byte[] encodeJsonBody(final R input) {
try {
return SystemMapper.jsonMapper().writeValueAsBytes(input);
} catch (final JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,150 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.integration;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketRequestMessage;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
public class WebsocketClientSession implements Session.Listener.AutoDemanding {
private static final Logger log = LoggerFactory.getLogger(WebsocketClientSession.class);
private final WebSocketMessageFactory messageFactory = new ProtobufWebSocketMessageFactory();
private final AtomicLong requestId = new AtomicLong();
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
private final List<MessageProtos.Envelope> receivedEnvelopes = new CopyOnWriteArrayList<>();
private final CompletableFuture<Session> opened = new CompletableFuture<>();
private final CompletableFuture<Void> queueEmpty = new CompletableFuture<>();
private final CompletableFuture<Integer> closed = new CompletableFuture<>();
@Override
public void onWebSocketOpen(final Session session) {
opened.complete(session);
}
@Override
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
final WebSocketMessage message = messageFactory.parseMessage(payload);
switch (message.getType()) {
case REQUEST_MESSAGE -> {
log.info("received request message {} {}", message.getRequestMessage().getVerb(), message.getRequestMessage().getPath());
switch (message.getRequestMessage().getPath()) {
case "/api/v1/message" -> acknowledge(message.getRequestMessage());
case "/api/v1/queue/empty" -> queueEmpty.complete(null);
default -> throw new IllegalStateException("Unexpected path: " + message.getRequestMessage().getPath());
}
}
case RESPONSE_MESSAGE -> {
final WebSocketResponseMessage response = message.getResponseMessage();
log.info("received response message {}", response.getStatus());
final CompletableFuture<WebSocketResponseMessage> future = responseFutures.remove(response.getRequestId());
if (future == null) {
throw new IllegalArgumentException("Received response with no matching request: {}" + response.getRequestId());
}
future.complete(response);
}
default -> throw new IllegalStateException("Unexpected message type: " + message.getType());
}
callback.succeed();
} catch (final Exception e) {
log.warn("Failed to process message received over the websocket", e);
callback.fail(e);
opened.join().close(1006, e.getMessage(), Callback.NOOP);
}
}
@Override
public void onWebSocketClose(final int statusCode, final String reason, final Callback callback) {
log.info("Received websocket close: {}", statusCode);
closed.complete(statusCode);
final IOException exception = new IOException("WebSocket closed: " + statusCode + " " + reason);
responseFutures.values()
.forEach(f -> f.completeExceptionally(exception));
responseFutures.clear();
if (!queueEmpty.isDone()) {
queueEmpty.completeExceptionally(exception);
}
callback.succeed();
}
public <T> WebSocketResponseMessage sendRequest(
final String verb,
final String path,
final List<String> headers,
final T body) {
final Session session = opened.join();
final long id = requestId.incrementAndGet();
final CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
responseFutures.put(id, future);
final Optional<byte[]> maybeBody = Optional.ofNullable(body).map(Operations::encodeJsonBody);
final byte[] bytes = messageFactory.createRequest(Optional.of(id), verb, path, headers, maybeBody).toByteArray();
session.sendBinary(ByteBuffer.wrap(bytes), Callback.from(() -> {}, throwable -> {
if (responseFutures.remove(id) != null) {
future.completeExceptionally(throwable);
}
}));
return future.join();
}
public List<MessageProtos.Envelope> getReceivedEnvelopes() {
return receivedEnvelopes;
}
public void waitForQueueEmpty() {
queueEmpty.join();
}
public void close(final int closeCode) {
final Session session = opened.join();
session.close(closeCode, "client close", Callback.NOOP);
closed.join();
}
private void acknowledge(WebSocketRequestMessage message) {
final byte[] envelopeBytes = message.getBody()
.orElseThrow(() -> new IllegalStateException("Messages should have a response body"));
try {
final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(envelopeBytes);
receivedEnvelopes.add(envelope);
final Session session = opened.join();
final WebSocketMessage response = messageFactory.createResponse(message.getRequestId(), 200, "",
Collections.emptyList(), Optional.empty());
session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP);
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException(e);
}
}
public static <R> R decode(Class<R> expectedType, WebSocketResponseMessage message) {
try {
return SystemMapper.jsonMapper()
.readValue(message.getBody().orElseThrow(() -> new IllegalStateException("No response body")), expectedType);
} catch (final IOException e) {
throw new UncheckedIOException(e);
}
}
}

View File

@ -6,43 +6,72 @@
package org.signal.integration;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import com.google.common.net.HttpHeaders;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import jakarta.ws.rs.core.MediaType;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
@Timeout(value = 1, unit = TimeUnit.MINUTES, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class MessagingTest {
TestUser userA;
TestUser userB;
@BeforeEach
public void setup() {
userA = Operations.newRegisteredUser("+19995550102");
userB = Operations.newRegisteredUser("+19995550103");
}
@AfterEach
public void teardown() {
Operations.deleteUser(userA);
Operations.deleteUser(userB);
}
@Test
public void testSendMessageUnsealed() {
final TestUser userA = Operations.newRegisteredUser("+19995550102");
final TestUser userB = Operations.newRegisteredUser("+19995550103");
try {
public void testSendMessageUnsealed() throws IOException {
final byte[] expectedContent = "Hello, World!".getBytes(StandardCharsets.UTF_8);
final IncomingMessage message = new IncomingMessage(1, Device.PRIMARY_ID, userB.registrationId(), expectedContent);
final IncomingMessageList messages = new IncomingMessageList(List.of(message), false, true, System.currentTimeMillis());
Operations
.apiPut("/v1/messages/%s".formatted(userB.aciUuid().toString()), messages)
.authorized(userA)
.execute(SendMessageResponse.class);
final WebsocketClientSession websocketA = Operations.authenticatedWebsocket(userA, Device.PRIMARY_ID);
final WebSocketResponseMessage responseMessage = websocketA.sendRequest(
"PUT",
"/v1/messages/%s".formatted(userB.aciUuid().toString()),
List.of(HttpHeaders.CONTENT_TYPE + ":" + MediaType.APPLICATION_JSON),
messages);
assertEquals(200, responseMessage.getStatus());
assertDoesNotThrow(() -> WebsocketClientSession.decode(SendMessageResponse.class, responseMessage));
final Pair<Integer, OutgoingMessageEntityList> receiveMessages = Operations.apiGet("/v1/messages")
.authorized(userB)
.execute(OutgoingMessageEntityList.class);
final WebsocketClientSession websocketB = Operations.authenticatedWebsocket(userB, Device.PRIMARY_ID);
assertTimeoutPreemptively(Duration.ofSeconds(5), websocketB::waitForQueueEmpty);
final byte[] actualContent = receiveMessages.getRight().messages().getFirst().content();
assertArrayEquals(expectedContent, actualContent);
} finally {
Operations.deleteUser(userA);
Operations.deleteUser(userB);
}
assertEquals(1, websocketB.getReceivedEnvelopes().size());
final MessageProtos.Envelope envelope = websocketB.getReceivedEnvelopes().getFirst();
assertArrayEquals(expectedContent, envelope.getContent().toByteArray());
websocketB.close(1000);
websocketA.close(1000);
}
}