diff --git a/integration-tests/src/main/java/org/signal/integration/WebsocketClientSession.java b/integration-tests/src/main/java/org/signal/integration/WebsocketClientSession.java index 7ab5259fd..e4a8c63ac 100644 --- a/integration-tests/src/main/java/org/signal/integration/WebsocketClientSession.java +++ b/integration-tests/src/main/java/org/signal/integration/WebsocketClientSession.java @@ -62,7 +62,7 @@ public class WebsocketClientSession implements Session.Listener.AutoDemanding { log.info("received response message {}", response.getStatus()); final CompletableFuture future = responseFutures.remove(response.getRequestId()); if (future == null) { - throw new IllegalArgumentException("Received response with no matching request: {}" + response.getRequestId()); + throw new IllegalArgumentException("Received response with no matching request: " + response.getRequestId()); } future.complete(response); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java index 69dff1991..6e8edae17 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java @@ -35,6 +35,7 @@ public final class MessageMetrics { public static String GRPC_CHANNEL = "grpc"; public static String WEBSOCKET_CHANNEL = "websocket"; + public static String MESSAGELESS_WEBSOCKET_CHANNEL = "messageless-websocket"; @VisibleForTesting static final String MISMATCHED_ACCOUNT_ENVELOPE_UUID_COUNTER_NAME = name(MessageMetrics.class, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index da8206e82..f089ff84a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -38,6 +38,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final AccountsManager accountsManager; private final DisconnectionRequestManager disconnectionRequestManager; private final WebSocketConnectionBuilder webSocketConnectionBuilder; + private final MessageMetrics messageMetrics; private final OpenWebSocketCounter openAuthenticatedWebSocketCounter; private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; @@ -66,6 +67,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { disconnectionRequestManager, asnInfoProviderSupplier, clientReleaseManager, + messageMetrics, (account, device, client) -> new WebSocketConnection(receiptSender, messagesManager, messageMetrics, @@ -86,11 +88,13 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { final DisconnectionRequestManager disconnectionRequestManager, final Supplier asnInfoProviderSupplier, final ClientReleaseManager clientReleaseManager, + final MessageMetrics messageMetrics, final WebSocketConnectionBuilder webSocketConnectionBuilder) { this.accountsManager = accountsManager; this.disconnectionRequestManager = disconnectionRequestManager; this.webSocketConnectionBuilder = webSocketConnectionBuilder; + this.messageMetrics = messageMetrics; this.openAuthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-authenticated", asnInfoProviderSupplier, clientReleaseManager); this.openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-unauthenticated", asnInfoProviderSupplier, clientReleaseManager); @@ -102,46 +106,52 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { final boolean authenticated = (context.getAuthenticated() != null); (authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter).countOpenWebSocket(context); + if (!authenticated) { + return; + } - if (authenticated) { - final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); + final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); - final Optional maybeAuthenticatedAccount = - accountsManager.getByAccountIdentifier(auth.accountIdentifier()); + final Optional maybeAuthenticatedAccount = + accountsManager.getByAccountIdentifier(auth.accountIdentifier()); - final Optional maybeAuthenticatedDevice = - maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId())); + final Optional maybeAuthenticatedDevice = + maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId())); - if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) { - log.warn("{}:{} not found when opening authenticated WebSocket", auth.accountIdentifier(), auth.deviceId()); + if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) { + log.warn("{}:{} not found when opening authenticated WebSocket", auth.accountIdentifier(), auth.deviceId()); - context.getClient().close(1011, "Unexpected error initializing connection"); - return; - } + context.getClient().close(1011, "Unexpected error initializing connection"); + return; + } - final WebSocketConnection connection = - webSocketConnectionBuilder.buildWebSocketConnection(maybeAuthenticatedAccount.get(), - maybeAuthenticatedDevice.get(), - context.getClient()); + final Account account = maybeAuthenticatedAccount.get(); + final Device device = maybeAuthenticatedDevice.get(); - disconnectionRequestManager.addListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI), - maybeAuthenticatedDevice.get().getId(), - connection); + final boolean disableMessages = context.getClient().shouldDisableMessages(); - context.addWebsocketClosedListener((_, _, _) -> { - disconnectionRequestManager.removeListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI), - maybeAuthenticatedDevice.get().getId(), - connection); + final Optional maybeWebSocketConnection = disableMessages + ? Optional.empty() + : Optional.of(webSocketConnectionBuilder.buildWebSocketConnection(account, device, context.getClient())); - connection.stop(); - }); + final WebSocketDisconnectionRequestListener disconnectionListener = + new WebSocketDisconnectionRequestListener(messageMetrics, context.getClient(), disableMessages); - try { - connection.start(); - } catch (final Exception e) { - log.warn("Failed to initialize websocket", e); - context.getClient().close(1011, "Unexpected error initializing connection"); - } + disconnectionRequestManager + .addListener(account.getIdentifier(IdentityType.ACI), device.getId(), disconnectionListener); + + context.addWebsocketClosedListener((_, _, _) -> { + disconnectionRequestManager + .removeListener(account.getIdentifier(IdentityType.ACI), device.getId(), disconnectionListener); + maybeWebSocketConnection.ifPresent(WebSocketConnection::stop); + }); + + try { + maybeWebSocketConnection.ifPresent(WebSocketConnection::start); + } catch (final Exception e) { + log.warn("Failed to initialize websocket", e); + context.getClient().close(1011, "Unexpected error initializing connection"); } } + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index ede79d7c2..d2b204cdc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -15,7 +15,6 @@ import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Timer; import java.time.Duration; import java.util.Collections; -import java.util.HexFormat; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -29,7 +28,6 @@ import org.apache.commons.lang3.StringUtils; import org.eclipse.jetty.util.StaticException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; @@ -62,7 +60,7 @@ import reactor.core.observability.micrometer.Micrometer; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; -public class WebSocketConnection implements DisconnectionRequestListener { +public class WebSocketConnection { private static final Counter sendFailuresCounter = Metrics.counter(name(WebSocketConnection.class, "sendFailures")); @@ -333,10 +331,4 @@ public class WebSocketConnection implements DisconnectionRequestListener { throwable instanceof org.eclipse.jetty.io.EofException || (throwable instanceof StaticException staticException && "Closed".equals(staticException.getMessage())); } - - @Override - public void handleDisconnectionRequest() { - messageMetrics.measureMessageStreamDisplaced(MessageMetrics.WEBSOCKET_CHANNEL, userAgent, false); - client.close(4401, "Reauthentication required"); - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketDisconnectionRequestListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketDisconnectionRequestListener.java new file mode 100644 index 000000000..41f09ada5 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketDisconnectionRequestListener.java @@ -0,0 +1,33 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.websocket; + +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; +import org.whispersystems.textsecuregcm.metrics.MessageMetrics; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; +import org.whispersystems.websocket.WebSocketClient; + +class WebSocketDisconnectionRequestListener implements DisconnectionRequestListener { + + private final MessageMetrics messageMetrics; + private final String metricChannel; + private final WebSocketClient client; + + WebSocketDisconnectionRequestListener(final MessageMetrics messageMetrics, final WebSocketClient client, final boolean disableMessages) { + this.messageMetrics = messageMetrics; + this.client = client; + this.metricChannel = disableMessages + ? MessageMetrics.MESSAGELESS_WEBSOCKET_CHANNEL + : MessageMetrics.WEBSOCKET_CHANNEL; + } + + @Override + public void handleDisconnectionRequest() { + final UserAgent userAgent = UserAgentUtil.maybeParseUserAgentString(client.getUserAgent()); + messageMetrics.measureMessageStreamDisplaced(metricChannel, userAgent, false); + client.close(4401, "Reauthentication required"); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java index 44d904646..1ad71f173 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java @@ -4,25 +4,37 @@ */ package org.whispersystems.textsecuregcm.tests.util; +import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicLong; +import com.google.protobuf.InvalidProtocolBufferException; import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketMessageFactory; +import org.whispersystems.websocket.messages.WebSocketRequestMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; public class TestWebsocketListener implements Session.Listener.AutoDemanding { + private static final Logger log = LoggerFactory.getLogger(TestWebsocketListener.class); private final AtomicLong requestId = new AtomicLong(); - private final CompletableFuture started = new CompletableFuture<>(); + protected final CompletableFuture started = new CompletableFuture<>(); + private final CompletableFuture queueEmpty = new CompletableFuture<>(); private final CompletableFuture closed = new CompletableFuture<>(); + private final List receivedEnvelopes = new CopyOnWriteArrayList<>(); private final ConcurrentHashMap> responseFutures = new ConcurrentHashMap<>(); protected final WebSocketMessageFactory messageFactory; @@ -46,6 +58,10 @@ public class TestWebsocketListener implements Session.Listener.AutoDemanding { return closed; } + public CompletableFuture queueEmptyFuture() { + return queueEmpty; + } + public CompletableFuture doGet(final String requestPath) { return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty()); } @@ -66,18 +82,55 @@ public class TestWebsocketListener implements Session.Listener.AutoDemanding { }); } + public List getReceivedEnvelopes() { + return receivedEnvelopes; + } + @Override public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); - if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { - responseFutures.get(webSocketMessage.getResponseMessage().getRequestId()) - .complete(webSocketMessage.getResponseMessage()); - } else { - throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType()); + final WebSocketMessage message = messageFactory.parseMessage(payload); + switch (message.getType()) { + case REQUEST_MESSAGE -> { + log.info("received request message {} {}", message.getRequestMessage().getVerb(), message.getRequestMessage().getPath()); + switch (message.getRequestMessage().getPath()) { + case "/api/v1/message" -> acknowledge(message.getRequestMessage()); + case "/api/v1/queue/empty" -> queueEmpty.complete(null); + default -> throw new IllegalStateException("Unexpected path: " + message.getRequestMessage().getPath()); + } + } + case RESPONSE_MESSAGE -> { + final WebSocketResponseMessage response = message.getResponseMessage(); + log.info("received response message {}", response.getStatus()); + final CompletableFuture future = responseFutures.remove(response.getRequestId()); + if (future == null) { + throw new IllegalArgumentException("Received response with no matching request: " + response.getRequestId()); + } + future.complete(response); + } + default -> throw new IllegalStateException("Unexpected message type: " + message.getType()); } + callback.succeed(); } catch (final Exception e) { - throw new RuntimeException(e); + log.warn("Failed to process message received over the websocket", e); + callback.fail(e); + started.join().close(1006, e.getMessage(), Callback.NOOP); + } + } + + private void acknowledge(WebSocketRequestMessage message) { + final byte[] envelopeBytes = message.getBody() + .orElseThrow(() -> new UncheckedIOException(new IOException("Messages should have a body"))); + try { + final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(envelopeBytes); + receivedEnvelopes.add(envelope); + final Session session = started.join(); + final WebSocketMessage response = messageFactory.createResponse(message.getRequestId(), 200, "", + Collections.emptyList(), Optional.empty()); + session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP); + + } catch (InvalidProtocolBufferException e) { + throw new UncheckedIOException(e); } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerIntegrationTest.java new file mode 100644 index 000000000..223ac2907 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerIntegrationTest.java @@ -0,0 +1,230 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.websocket; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME; + +import com.google.i18n.phonenumbers.PhoneNumberUtil; +import com.google.protobuf.ByteString; +import io.dropwizard.core.Application; +import io.dropwizard.core.Configuration; +import io.dropwizard.core.setup.Environment; +import io.dropwizard.testing.junit5.DropwizardAppExtension; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import java.net.URI; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.apache.commons.lang3.RandomStringUtils; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; +import org.whispersystems.textsecuregcm.metrics.MessageMetrics; +import org.whispersystems.textsecuregcm.push.PushNotificationManager; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; +import org.whispersystems.textsecuregcm.push.ReceiptSender; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.MessageStream; +import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; +import org.whispersystems.textsecuregcm.storage.MessagesManager; +import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; +import org.whispersystems.textsecuregcm.util.UUIDUtil; +import org.whispersystems.websocket.WebSocketResourceProviderFactory; +import org.whispersystems.websocket.WebsocketHeaders; +import org.whispersystems.websocket.configuration.WebSocketConfiguration; +import org.whispersystems.websocket.setup.WebSocketEnvironment; +import reactor.adapter.JdkFlowAdapter; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +@ExtendWith(DropwizardExtensionsSupport.class) +@Timeout(value = 15, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +class AuthenticatedConnectListenerIntegrationTest { + + @RegisterExtension + static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = + new DropwizardAppExtension<>(TestApplication.class); + + private static final UUID ACCOUNT_UUID = UUID.randomUUID(); + private static final byte DEVICE_ID = Device.PRIMARY_ID; + private static final String E164 = PhoneNumberUtil.getInstance() + .format(PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164); + + + private static final DisconnectionRequestManager disconnectionRequestManager = mock(DisconnectionRequestManager.class); + private static final MessagesManager messagesManager = mock(MessagesManager.class); + private static final AccountsManager accountsManager = mock(AccountsManager.class); + private static final Account account = mock(Account.class); + private static final Device device = mock(Device.class); + + private WebSocketClient client; + + @BeforeEach + void setUp() throws Exception { + reset(messagesManager, disconnectionRequestManager, accountsManager, account, device); + when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); + when(account.getNumber()).thenReturn(E164); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_UUID); + when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device)); + when(device.getId()).thenReturn(DEVICE_ID); + when(accountsManager.getByAccountIdentifier(ACCOUNT_UUID)).thenReturn(Optional.of(account)); + + client = new WebSocketClient(); + client.start(); + } + + @AfterEach + void tearDown() throws Exception { + client.stop(); + } + + public static class TestApplication extends Application { + + @Override + public void run(final Configuration configuration, final Environment environment) { + final AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener( + accountsManager, + mock(ReceiptSender.class), + messagesManager, + new MessageMetrics(), + mock(PushNotificationManager.class), + mock(PushNotificationScheduler.class), + disconnectionRequestManager, + Schedulers.boundedElastic(), + () -> mock(AsnInfoProvider.class), + mock(ClientReleaseManager.class), + mock(MessageDeliveryLoopMonitor.class), + mock(ExperimentEnrollmentManager.class)); + + final WebSocketEnvironment webSocketEnvironment = + new WebSocketEnvironment<>(environment, new WebSocketConfiguration()); + + webSocketEnvironment.setAuthenticator(_ -> + Optional.of(new AuthenticatedDevice(ACCOUNT_UUID, DEVICE_ID, Instant.now()))); + webSocketEnvironment.setConnectListener(connectListener); + + final WebSocketResourceProviderFactory webSocketServlet = + new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, + REMOTE_ADDRESS_ATTRIBUTE_NAME); + + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), + (servletContext, container) -> { + container.addMapping("/websocket", webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); + } + } + + @Test + void messagesDeliveredOnlyWhenHeaderAbsent() throws Exception { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); + + final MessageStream messageStream = mock(MessageStream.class); + when(messageStream.getMessages()).thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just( + new MessageStreamEntry.Envelope(envelope), + new MessageStreamEntry.QueueEmpty()))); + when(messageStream.acknowledgeMessage(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null)); + when(messagesManager.getMessages(ACCOUNT_UUID, device)).thenReturn(messageStream); + final URI uri = URI.create("ws://127.0.0.1:" + DROPWIZARD_APP_EXTENSION.getLocalPort() + "/websocket"); + + final TestWebsocketListener disabledListener = new TestWebsocketListener(); + final ClientUpgradeRequest disabledRequest = new ClientUpgradeRequest(uri); + disabledRequest.setHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES, "true"); + try (Session _ = client.connect(disabledListener, disabledRequest).get(5, TimeUnit.SECONDS)) { + assertThrows(TimeoutException.class, () -> disabledListener.queueEmptyFuture().get(10, TimeUnit.MILLISECONDS)); + assertTrue(disabledListener.getReceivedEnvelopes().isEmpty()); + assertFalse(disabledListener.queueEmptyFuture().isDone()); + } + + final TestWebsocketListener enabledListener = new TestWebsocketListener(); + try (Session ignored = client.connect(enabledListener, uri).get(5, TimeUnit.SECONDS)) { + enabledListener.queueEmptyFuture().get(5, TimeUnit.SECONDS); + assertEquals(1, enabledListener.getReceivedEnvelopes().size()); + assertEquals( + UUIDUtil.toByteString(messageGuid), + enabledListener.getReceivedEnvelopes().getFirst().getServerGuid()); + } + } + + @Test + void allDisconnected() throws Exception { + final MessageStream messageStream = mock(MessageStream.class); + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.never())); + when(messagesManager.getMessages(ACCOUNT_UUID, device)).thenReturn(messageStream); + + final URI uri = URI.create("ws://127.0.0.1:" + DROPWIZARD_APP_EXTENSION.getLocalPort() + "/websocket"); + + final TestWebsocketListener disabledListener = new TestWebsocketListener(); + final TestWebsocketListener enabledListener = new TestWebsocketListener(); + final ClientUpgradeRequest disabledRequest = new ClientUpgradeRequest(uri); + disabledRequest.setHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES, "true"); + + final ArgumentCaptor captor = + ArgumentCaptor.forClass(DisconnectionRequestListener.class); + client.connect(disabledListener, disabledRequest).get(5, TimeUnit.SECONDS); + client.connect(enabledListener, uri).get(5, TimeUnit.SECONDS); + + // Simulate a disconnection request + verify(disconnectionRequestManager, timeout(1000).times(2)) + .addListener(eq(ACCOUNT_UUID), eq(DEVICE_ID), captor.capture()); + captor.getAllValues().forEach(listener -> listener.handleDisconnectionRequest()); + + assertEquals(4401, disabledListener.closeFuture().join()); + assertEquals(4401, enabledListener.closeFuture().join()); + } + + private static MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) { + final long timestamp = System.currentTimeMillis(); + return MessageProtos.Envelope.newBuilder() + .setClientTimestamp(timestamp) + .setServerTimestamp(timestamp) + .setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256))) + .setType(MessageProtos.Envelope.Type.CIPHERTEXT) + .setServerGuid(UUIDUtil.toByteString(messageGuid)) + .setSourceServiceId(new AciServiceIdentifier(UUID.randomUUID()).toCompactByteString()) + .setDestinationServiceId(new AciServiceIdentifier(ACCOUNT_UUID).toCompactByteString()) + .build(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java index de7135691..537ea3ef9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java @@ -21,10 +21,13 @@ import java.util.Optional; import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; @@ -58,6 +61,7 @@ class AuthenticatedConnectListenerTest { disconnectionRequestManager, () -> mock(AsnInfoProvider.class), mock(ClientReleaseManager.class), + mock(MessageMetrics.class), (_, _, _) -> authenticatedWebSocketConnection); final Device device = mock(Device.class); @@ -83,10 +87,18 @@ class AuthenticatedConnectListenerTest { authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); - verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection); + final ArgumentCaptor disconnectListener = ArgumentCaptor.forClass(DisconnectionRequestListener.class); + final ArgumentCaptor closeListener = + ArgumentCaptor.forClass(WebSocketSessionContext.WebSocketEventListener.class); + verify(disconnectionRequestManager).addListener(eq(ACCOUNT_IDENTIFIER), eq(DEVICE_ID), disconnectListener.capture()); // We expect one call from AuthenticatedConnectListener itself and one from OpenWebSocketCounter - verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(any()); + verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(closeListener.capture()); verify(authenticatedWebSocketConnection).start(); + + // Verify that if we run the close listeners, the disconnection listener gets removed + closeListener.getAllValues() + .forEach(c -> c.onWebSocketClose(webSocketSessionContext, 1011, "test")); + verify(disconnectionRequestManager).removeListener(ACCOUNT_IDENTIFIER, DEVICE_ID, disconnectListener.getValue()); } @Test @@ -118,12 +130,21 @@ class AuthenticatedConnectListenerTest { authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); - verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection); - // We expect one call from AuthenticatedConnectListener itself and one from OpenWebSocketCounter - verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(any()); + final ArgumentCaptor disconnectListener = ArgumentCaptor.forClass(DisconnectionRequestListener.class); + final ArgumentCaptor closeListener = + ArgumentCaptor.forClass(WebSocketSessionContext.WebSocketEventListener.class); + + // We expect one call from OpenWebSocketCounter and one from AuthenticatedConnectListener itself + verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(closeListener.capture()); + verify(disconnectionRequestManager).addListener(eq(ACCOUNT_IDENTIFIER), eq(DEVICE_ID), disconnectListener.capture()); verify(authenticatedWebSocketConnection).start(); verify(webSocketClient).close(eq(1011), anyString()); + + // Verify that if we run the close listeners, the disconnection listener gets removed + closeListener.getAllValues() + .forEach(c -> c.onWebSocketClose(webSocketSessionContext, 1011, "test")); + verify(disconnectionRequestManager).removeListener(ACCOUNT_IDENTIFIER, DEVICE_ID, disconnectListener.getValue()); } @Test diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java index bc1cf9da9..b2597dd3f 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java @@ -97,7 +97,12 @@ public class WebSocketClient { public boolean shouldDeliverStories() { String value = session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES); - return WebsocketHeaders.parseReceiveStoriesHeader(value); + return WebsocketHeaders.parseBooleanHeader(value); + } + + public boolean shouldDisableMessages() { + final String value = session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES); + return WebsocketHeaders.parseBooleanHeader(value); } private long generateRequestId() { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebsocketHeaders.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebsocketHeaders.java index 73703005d..7b5b6624f 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebsocketHeaders.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebsocketHeaders.java @@ -6,7 +6,13 @@ package org.whispersystems.websocket; public class WebsocketHeaders { public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories"; - public static boolean parseReceiveStoriesHeader(String s) { + public final static String X_SIGNAL_DISABLE_MESSAGES = "X-Signal-Disable-Messages"; + + /// Parse a boolean header value + /// + /// @param s the value of an HTTP header + /// @return true if the header is "true", otherwise false + public static boolean parseBooleanHeader(String s) { return "true".equals(s); } }