diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index cd2b18d7b..bc5af0a6c 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -10,6 +10,10 @@ 4.0.0 integration-tests + + 12.1.5 + + org.whispersystems.textsecure @@ -21,6 +25,22 @@ software.amazon.awssdk dynamodb + + + org.eclipse.jetty.websocket + jetty-websocket-jetty-api + + + + org.eclipse.jetty.websocket + jetty-websocket-jetty-client + + + + org.eclipse.jetty.http2 + jetty-http2-client-transport + ${jetty.http2-client.version} + diff --git a/integration-tests/src/main/java/org/signal/integration/Operations.java b/integration-tests/src/main/java/org/signal/integration/Operations.java index 8fcc758a1..47b8774df 100644 --- a/integration-tests/src/main/java/org/signal/integration/Operations.java +++ b/integration-tests/src/main/java/org/signal/integration/Operations.java @@ -20,18 +20,28 @@ import java.net.URL; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; +import java.security.KeyStore; import java.security.SecureRandom; import java.security.cert.CertificateException; import java.util.ArrayList; import java.util.Base64; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.Validate; import org.apache.commons.lang3.tuple.Pair; +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.http2.client.HTTP2Client; +import org.eclipse.jetty.http2.client.transport.HttpClientTransportOverHTTP2; +import org.eclipse.jetty.io.ClientConnector; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; import org.signal.integration.config.Config; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.ECKeyPair; @@ -49,6 +59,7 @@ import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.RegistrationRequest; import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.CertificateUtil; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HttpUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -65,6 +76,8 @@ public final class Operations { private static final FaultTolerantHttpClient CLIENT = buildClient(); + private static final WebSocketClient WEB_SOCKET_CLIENT = buildWebSocketClient(); + private Operations() { // utility class @@ -218,14 +231,10 @@ public final class Operations { } private static RequestBuilder withJsonBody(final String endpoint, final String method, final R input) { - try { - final byte[] body = SystemMapper.jsonMapper().writeValueAsBytes(input); - return new RequestBuilder(HttpRequest.newBuilder() - .header(HttpHeaders.CONTENT_TYPE, "application/json") - .method(method, HttpRequest.BodyPublishers.ofByteArray(body)), endpoint); - } catch (final JsonProcessingException e) { - throw new RuntimeException(e); - } + final byte[] body = encodeJsonBody(input); + return new RequestBuilder(HttpRequest.newBuilder() + .header(HttpHeaders.CONTENT_TYPE, "application/json") + .method(method, HttpRequest.BodyPublishers.ofByteArray(body)), endpoint); } public RequestBuilder authorized(final TestUser user) { @@ -305,6 +314,7 @@ public final class Operations { }) .join(); } + } private static FaultTolerantHttpClient buildClient() { @@ -317,6 +327,53 @@ public final class Operations { } } + private static WebSocketClient buildWebSocketClient() { + try { + final KeyStore trustStore = CertificateUtil.buildKeyStoreForPem(CONFIG.rootCert()); + final SslContextFactory.Client sslContextFactory = new SslContextFactory.Client(); + sslContextFactory.setTrustStore(trustStore); + + final ClientConnector connector = new ClientConnector(); + connector.setSslContextFactory(sslContextFactory); + + final HTTP2Client http2Client = new HTTP2Client(connector); + final HttpClient httpClient = new HttpClient(new HttpClientTransportOverHTTP2(http2Client)); + + final WebSocketClient wsClient = new WebSocketClient(httpClient); + wsClient.start(); + return wsClient; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static WebsocketClientSession authenticatedWebsocket(final TestUser user, final byte deviceId) throws IOException { + final String username = "%s.%d".formatted(user.aciUuid().toString(), deviceId); + return connect("/v1/websocket/", Map.of(HttpHeaders.AUTHORIZATION, HeaderUtils.basicAuthHeader(username, user.accountPassword()))); + } + + public static WebsocketClientSession anonymousWebsocket() throws IOException { + return connect("/v1/websocket/", Collections.emptyMap()); + } + + private static WebsocketClientSession connect( + final String path, + final Map headers) throws IOException { + + final URI uri = URI.create("wss://grpc." + CONFIG.domain() + path); + final ClientUpgradeRequest request = new ClientUpgradeRequest(uri); + headers.forEach(request::setHeader); + + final WebsocketClientSession listener = new WebsocketClientSession(); + try { + WEB_SOCKET_CLIENT.connect(listener, request).get(5, TimeUnit.SECONDS); + } catch (Exception e) { + throw new IOException(e); + } + logger.info("Successfully connected to websocket on {}", uri); + return listener; + } + private static Config loadConfigFromClasspath(final String filename) { try { final URL configFileUrl = Resources.getResource(filename); @@ -345,4 +402,12 @@ public final class Operations { final byte[] signature = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize()); return new KEMSignedPreKey(id, pubKey, signature); } + + public static byte[] encodeJsonBody(final R input) { + try { + return SystemMapper.jsonMapper().writeValueAsBytes(input); + } catch (final JsonProcessingException e) { + throw new RuntimeException(e); + } + } } diff --git a/integration-tests/src/main/java/org/signal/integration/WebsocketClientSession.java b/integration-tests/src/main/java/org/signal/integration/WebsocketClientSession.java new file mode 100644 index 000000000..7ab5259fd --- /dev/null +++ b/integration-tests/src/main/java/org/signal/integration/WebsocketClientSession.java @@ -0,0 +1,150 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.signal.integration; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicLong; +import org.eclipse.jetty.websocket.api.Callback; +import org.eclipse.jetty.websocket.api.Session; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.websocket.messages.WebSocketMessage; +import org.whispersystems.websocket.messages.WebSocketMessageFactory; +import org.whispersystems.websocket.messages.WebSocketRequestMessage; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; + +public class WebsocketClientSession implements Session.Listener.AutoDemanding { + + private static final Logger log = LoggerFactory.getLogger(WebsocketClientSession.class); + private final WebSocketMessageFactory messageFactory = new ProtobufWebSocketMessageFactory(); + private final AtomicLong requestId = new AtomicLong(); + private final ConcurrentHashMap> responseFutures = new ConcurrentHashMap<>(); + private final List receivedEnvelopes = new CopyOnWriteArrayList<>(); + private final CompletableFuture opened = new CompletableFuture<>(); + private final CompletableFuture queueEmpty = new CompletableFuture<>(); + private final CompletableFuture closed = new CompletableFuture<>(); + + @Override + public void onWebSocketOpen(final Session session) { + opened.complete(session); + } + + + @Override + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { + try { + final WebSocketMessage message = messageFactory.parseMessage(payload); + switch (message.getType()) { + case REQUEST_MESSAGE -> { + log.info("received request message {} {}", message.getRequestMessage().getVerb(), message.getRequestMessage().getPath()); + switch (message.getRequestMessage().getPath()) { + case "/api/v1/message" -> acknowledge(message.getRequestMessage()); + case "/api/v1/queue/empty" -> queueEmpty.complete(null); + default -> throw new IllegalStateException("Unexpected path: " + message.getRequestMessage().getPath()); + } + } + case RESPONSE_MESSAGE -> { + final WebSocketResponseMessage response = message.getResponseMessage(); + log.info("received response message {}", response.getStatus()); + final CompletableFuture future = responseFutures.remove(response.getRequestId()); + if (future == null) { + throw new IllegalArgumentException("Received response with no matching request: {}" + response.getRequestId()); + } + future.complete(response); + } + default -> throw new IllegalStateException("Unexpected message type: " + message.getType()); + } + callback.succeed(); + } catch (final Exception e) { + log.warn("Failed to process message received over the websocket", e); + callback.fail(e); + opened.join().close(1006, e.getMessage(), Callback.NOOP); + } + } + + @Override + public void onWebSocketClose(final int statusCode, final String reason, final Callback callback) { + log.info("Received websocket close: {}", statusCode); + closed.complete(statusCode); + final IOException exception = new IOException("WebSocket closed: " + statusCode + " " + reason); + responseFutures.values() + .forEach(f -> f.completeExceptionally(exception)); + responseFutures.clear(); + if (!queueEmpty.isDone()) { + queueEmpty.completeExceptionally(exception); + } + callback.succeed(); + } + + public WebSocketResponseMessage sendRequest( + final String verb, + final String path, + final List headers, + final T body) { + final Session session = opened.join(); + final long id = requestId.incrementAndGet(); + final CompletableFuture future = new CompletableFuture<>(); + responseFutures.put(id, future); + final Optional maybeBody = Optional.ofNullable(body).map(Operations::encodeJsonBody); + final byte[] bytes = messageFactory.createRequest(Optional.of(id), verb, path, headers, maybeBody).toByteArray(); + session.sendBinary(ByteBuffer.wrap(bytes), Callback.from(() -> {}, throwable -> { + if (responseFutures.remove(id) != null) { + future.completeExceptionally(throwable); + } + })); + return future.join(); + } + + public List getReceivedEnvelopes() { + return receivedEnvelopes; + } + + public void waitForQueueEmpty() { + queueEmpty.join(); + } + + public void close(final int closeCode) { + final Session session = opened.join(); + session.close(closeCode, "client close", Callback.NOOP); + closed.join(); + } + + private void acknowledge(WebSocketRequestMessage message) { + final byte[] envelopeBytes = message.getBody() + .orElseThrow(() -> new IllegalStateException("Messages should have a response body")); + try { + final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(envelopeBytes); + receivedEnvelopes.add(envelope); + final Session session = opened.join(); + final WebSocketMessage response = messageFactory.createResponse(message.getRequestId(), 200, "", + Collections.emptyList(), Optional.empty()); + session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP); + + } catch (InvalidProtocolBufferException e) { + throw new IllegalStateException(e); + } + } + + public static R decode(Class expectedType, WebSocketResponseMessage message) { + try { + return SystemMapper.jsonMapper() + .readValue(message.getBody().orElseThrow(() -> new IllegalStateException("No response body")), expectedType); + } catch (final IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/integration-tests/src/test/java/org/signal/integration/MessagingTest.java b/integration-tests/src/test/java/org/signal/integration/MessagingTest.java index cf32b3a8a..a5e49e7a5 100644 --- a/integration-tests/src/test/java/org/signal/integration/MessagingTest.java +++ b/integration-tests/src/test/java/org/signal/integration/MessagingTest.java @@ -6,43 +6,72 @@ package org.signal.integration; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import com.google.common.net.HttpHeaders; +import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.List; -import org.apache.commons.lang3.tuple.Pair; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import jakarta.ws.rs.core.MediaType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; +import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.SendMessageResponse; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; +@Timeout(value = 1, unit = TimeUnit.MINUTES, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public class MessagingTest { + TestUser userA; + TestUser userB; + + @BeforeEach + public void setup() { + userA = Operations.newRegisteredUser("+19995550102"); + userB = Operations.newRegisteredUser("+19995550103"); + } + + @AfterEach + public void teardown() { + Operations.deleteUser(userA); + Operations.deleteUser(userB); + } @Test - public void testSendMessageUnsealed() { - final TestUser userA = Operations.newRegisteredUser("+19995550102"); - final TestUser userB = Operations.newRegisteredUser("+19995550103"); - - try { + public void testSendMessageUnsealed() throws IOException { final byte[] expectedContent = "Hello, World!".getBytes(StandardCharsets.UTF_8); final IncomingMessage message = new IncomingMessage(1, Device.PRIMARY_ID, userB.registrationId(), expectedContent); final IncomingMessageList messages = new IncomingMessageList(List.of(message), false, true, System.currentTimeMillis()); - Operations - .apiPut("/v1/messages/%s".formatted(userB.aciUuid().toString()), messages) - .authorized(userA) - .execute(SendMessageResponse.class); + final WebsocketClientSession websocketA = Operations.authenticatedWebsocket(userA, Device.PRIMARY_ID); + final WebSocketResponseMessage responseMessage = websocketA.sendRequest( + "PUT", + "/v1/messages/%s".formatted(userB.aciUuid().toString()), + List.of(HttpHeaders.CONTENT_TYPE + ":" + MediaType.APPLICATION_JSON), + messages); + assertEquals(200, responseMessage.getStatus()); + assertDoesNotThrow(() -> WebsocketClientSession.decode(SendMessageResponse.class, responseMessage)); - final Pair receiveMessages = Operations.apiGet("/v1/messages") - .authorized(userB) - .execute(OutgoingMessageEntityList.class); + final WebsocketClientSession websocketB = Operations.authenticatedWebsocket(userB, Device.PRIMARY_ID); + assertTimeoutPreemptively(Duration.ofSeconds(5), websocketB::waitForQueueEmpty); - final byte[] actualContent = receiveMessages.getRight().messages().getFirst().content(); - assertArrayEquals(expectedContent, actualContent); - } finally { - Operations.deleteUser(userA); - Operations.deleteUser(userB); - } + assertEquals(1, websocketB.getReceivedEnvelopes().size()); + final MessageProtos.Envelope envelope = websocketB.getReceivedEnvelopes().getFirst(); + assertArrayEquals(expectedContent, envelope.getContent().toByteArray()); + + websocketB.close(1000); + websocketA.close(1000); } }