Add header to disable messages on authenticated websocket
Some checks failed
Update Documentation / build (push) Has been cancelled
Service CI / build (push) Has been cancelled
Integration Tests / build (push) Has been cancelled

This commit is contained in:
Ravi Khadiwala 2026-06-12 11:43:54 -05:00 committed by ravi-signal
parent 8deb5a803a
commit ac720595e6
10 changed files with 406 additions and 55 deletions

View File

@ -62,7 +62,7 @@ public class WebsocketClientSession implements Session.Listener.AutoDemanding {
log.info("received response message {}", response.getStatus());
final CompletableFuture<WebSocketResponseMessage> future = responseFutures.remove(response.getRequestId());
if (future == null) {
throw new IllegalArgumentException("Received response with no matching request: {}" + response.getRequestId());
throw new IllegalArgumentException("Received response with no matching request: " + response.getRequestId());
}
future.complete(response);
}

View File

@ -35,6 +35,7 @@ public final class MessageMetrics {
public static String GRPC_CHANNEL = "grpc";
public static String WEBSOCKET_CHANNEL = "websocket";
public static String MESSAGELESS_WEBSOCKET_CHANNEL = "messageless-websocket";
@VisibleForTesting
static final String MISMATCHED_ACCOUNT_ENVELOPE_UUID_COUNTER_NAME = name(MessageMetrics.class,

View File

@ -38,6 +38,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final AccountsManager accountsManager;
private final DisconnectionRequestManager disconnectionRequestManager;
private final WebSocketConnectionBuilder webSocketConnectionBuilder;
private final MessageMetrics messageMetrics;
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
@ -66,6 +67,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
disconnectionRequestManager,
asnInfoProviderSupplier,
clientReleaseManager,
messageMetrics,
(account, device, client) -> new WebSocketConnection(receiptSender,
messagesManager,
messageMetrics,
@ -86,11 +88,13 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final DisconnectionRequestManager disconnectionRequestManager,
final Supplier<AsnInfoProvider> asnInfoProviderSupplier,
final ClientReleaseManager clientReleaseManager,
final MessageMetrics messageMetrics,
final WebSocketConnectionBuilder webSocketConnectionBuilder) {
this.accountsManager = accountsManager;
this.disconnectionRequestManager = disconnectionRequestManager;
this.webSocketConnectionBuilder = webSocketConnectionBuilder;
this.messageMetrics = messageMetrics;
this.openAuthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-authenticated", asnInfoProviderSupplier, clientReleaseManager);
this.openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-unauthenticated", asnInfoProviderSupplier, clientReleaseManager);
@ -102,8 +106,10 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final boolean authenticated = (context.getAuthenticated() != null);
(authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter).countOpenWebSocket(context);
if (!authenticated) {
return;
}
if (authenticated) {
final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class);
final Optional<Account> maybeAuthenticatedAccount =
@ -119,29 +125,33 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
return;
}
final WebSocketConnection connection =
webSocketConnectionBuilder.buildWebSocketConnection(maybeAuthenticatedAccount.get(),
maybeAuthenticatedDevice.get(),
context.getClient());
final Account account = maybeAuthenticatedAccount.get();
final Device device = maybeAuthenticatedDevice.get();
disconnectionRequestManager.addListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI),
maybeAuthenticatedDevice.get().getId(),
connection);
final boolean disableMessages = context.getClient().shouldDisableMessages();
final Optional<WebSocketConnection> maybeWebSocketConnection = disableMessages
? Optional.empty()
: Optional.of(webSocketConnectionBuilder.buildWebSocketConnection(account, device, context.getClient()));
final WebSocketDisconnectionRequestListener disconnectionListener =
new WebSocketDisconnectionRequestListener(messageMetrics, context.getClient(), disableMessages);
disconnectionRequestManager
.addListener(account.getIdentifier(IdentityType.ACI), device.getId(), disconnectionListener);
context.addWebsocketClosedListener((_, _, _) -> {
disconnectionRequestManager.removeListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI),
maybeAuthenticatedDevice.get().getId(),
connection);
connection.stop();
disconnectionRequestManager
.removeListener(account.getIdentifier(IdentityType.ACI), device.getId(), disconnectionListener);
maybeWebSocketConnection.ifPresent(WebSocketConnection::stop);
});
try {
connection.start();
maybeWebSocketConnection.ifPresent(WebSocketConnection::start);
} catch (final Exception e) {
log.warn("Failed to initialize websocket", e);
context.getClient().close(1011, "Unexpected error initializing connection");
}
}
}
}

View File

@ -15,7 +15,6 @@ import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import java.time.Duration;
import java.util.Collections;
import java.util.HexFormat;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@ -29,7 +28,6 @@ import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.util.StaticException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
@ -62,7 +60,7 @@ import reactor.core.observability.micrometer.Micrometer;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
public class WebSocketConnection implements DisconnectionRequestListener {
public class WebSocketConnection {
private static final Counter sendFailuresCounter = Metrics.counter(name(WebSocketConnection.class, "sendFailures"));
@ -333,10 +331,4 @@ public class WebSocketConnection implements DisconnectionRequestListener {
throwable instanceof org.eclipse.jetty.io.EofException ||
(throwable instanceof StaticException staticException && "Closed".equals(staticException.getMessage()));
}
@Override
public void handleDisconnectionRequest() {
messageMetrics.measureMessageStreamDisplaced(MessageMetrics.WEBSOCKET_CHANNEL, userAgent, false);
client.close(4401, "Reauthentication required");
}
}

View File

@ -0,0 +1,33 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.websocket;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.WebSocketClient;
class WebSocketDisconnectionRequestListener implements DisconnectionRequestListener {
private final MessageMetrics messageMetrics;
private final String metricChannel;
private final WebSocketClient client;
WebSocketDisconnectionRequestListener(final MessageMetrics messageMetrics, final WebSocketClient client, final boolean disableMessages) {
this.messageMetrics = messageMetrics;
this.client = client;
this.metricChannel = disableMessages
? MessageMetrics.MESSAGELESS_WEBSOCKET_CHANNEL
: MessageMetrics.WEBSOCKET_CHANNEL;
}
@Override
public void handleDisconnectionRequest() {
final UserAgent userAgent = UserAgentUtil.maybeParseUserAgentString(client.getUserAgent());
messageMetrics.measureMessageStreamDisplaced(metricChannel, userAgent, false);
client.close(4401, "Reauthentication required");
}
}

View File

@ -4,25 +4,37 @@
*/
package org.whispersystems.textsecuregcm.tests.util;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import com.google.protobuf.InvalidProtocolBufferException;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketRequestMessage;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
public class TestWebsocketListener implements Session.Listener.AutoDemanding {
private static final Logger log = LoggerFactory.getLogger(TestWebsocketListener.class);
private final AtomicLong requestId = new AtomicLong();
private final CompletableFuture<Session> started = new CompletableFuture<>();
protected final CompletableFuture<Session> started = new CompletableFuture<>();
private final CompletableFuture<Void> queueEmpty = new CompletableFuture<>();
private final CompletableFuture<Integer> closed = new CompletableFuture<>();
private final List<MessageProtos.Envelope> receivedEnvelopes = new CopyOnWriteArrayList<>();
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
protected final WebSocketMessageFactory messageFactory;
@ -46,6 +58,10 @@ public class TestWebsocketListener implements Session.Listener.AutoDemanding {
return closed;
}
public CompletableFuture<Void> queueEmptyFuture() {
return queueEmpty;
}
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
}
@ -66,18 +82,55 @@ public class TestWebsocketListener implements Session.Listener.AutoDemanding {
});
}
public List<MessageProtos.Envelope> getReceivedEnvelopes() {
return receivedEnvelopes;
}
@Override
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
.complete(webSocketMessage.getResponseMessage());
} else {
throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
final WebSocketMessage message = messageFactory.parseMessage(payload);
switch (message.getType()) {
case REQUEST_MESSAGE -> {
log.info("received request message {} {}", message.getRequestMessage().getVerb(), message.getRequestMessage().getPath());
switch (message.getRequestMessage().getPath()) {
case "/api/v1/message" -> acknowledge(message.getRequestMessage());
case "/api/v1/queue/empty" -> queueEmpty.complete(null);
default -> throw new IllegalStateException("Unexpected path: " + message.getRequestMessage().getPath());
}
}
case RESPONSE_MESSAGE -> {
final WebSocketResponseMessage response = message.getResponseMessage();
log.info("received response message {}", response.getStatus());
final CompletableFuture<WebSocketResponseMessage> future = responseFutures.remove(response.getRequestId());
if (future == null) {
throw new IllegalArgumentException("Received response with no matching request: " + response.getRequestId());
}
future.complete(response);
}
default -> throw new IllegalStateException("Unexpected message type: " + message.getType());
}
callback.succeed();
} catch (final Exception e) {
throw new RuntimeException(e);
log.warn("Failed to process message received over the websocket", e);
callback.fail(e);
started.join().close(1006, e.getMessage(), Callback.NOOP);
}
}
private void acknowledge(WebSocketRequestMessage message) {
final byte[] envelopeBytes = message.getBody()
.orElseThrow(() -> new UncheckedIOException(new IOException("Messages should have a body")));
try {
final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(envelopeBytes);
receivedEnvelopes.add(envelope);
final Session session = started.join();
final WebSocketMessage response = messageFactory.createResponse(message.getRequestId(), 200, "",
Collections.emptyList(), Optional.empty());
session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP);
} catch (InvalidProtocolBufferException e) {
throw new UncheckedIOException(e);
}
}
}

View File

@ -0,0 +1,230 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.protobuf.ByteString;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.net.URI;
import java.time.Instant;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
import reactor.adapter.JdkFlowAdapter;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
@ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 15, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class AuthenticatedConnectListenerIntegrationTest {
@RegisterExtension
static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID;
private static final String E164 = PhoneNumberUtil.getInstance()
.format(PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164);
private static final DisconnectionRequestManager disconnectionRequestManager = mock(DisconnectionRequestManager.class);
private static final MessagesManager messagesManager = mock(MessagesManager.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final Account account = mock(Account.class);
private static final Device device = mock(Device.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
reset(messagesManager, disconnectionRequestManager, accountsManager, account, device);
when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false));
when(account.getNumber()).thenReturn(E164);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_UUID);
when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(DEVICE_ID);
when(accountsManager.getByAccountIdentifier(ACCOUNT_UUID)).thenReturn(Optional.of(account));
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) {
final AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(
accountsManager,
mock(ReceiptSender.class),
messagesManager,
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
disconnectionRequestManager,
Schedulers.boundedElastic(),
() -> mock(AsnInfoProvider.class),
mock(ClientReleaseManager.class),
mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class));
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, new WebSocketConfiguration());
webSocketEnvironment.setAuthenticator(_ ->
Optional.of(new AuthenticatedDevice(ACCOUNT_UUID, DEVICE_ID, Instant.now())));
webSocketEnvironment.setConnectListener(connectListener);
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(),
(servletContext, container) -> {
container.addMapping("/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
}
}
@Test
void messagesDeliveredOnlyWhenHeaderAbsent() throws Exception {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
final MessageStream messageStream = mock(MessageStream.class);
when(messageStream.getMessages()).thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just(
new MessageStreamEntry.Envelope(envelope),
new MessageStreamEntry.QueueEmpty())));
when(messageStream.acknowledgeMessage(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.getMessages(ACCOUNT_UUID, device)).thenReturn(messageStream);
final URI uri = URI.create("ws://127.0.0.1:" + DROPWIZARD_APP_EXTENSION.getLocalPort() + "/websocket");
final TestWebsocketListener disabledListener = new TestWebsocketListener();
final ClientUpgradeRequest disabledRequest = new ClientUpgradeRequest(uri);
disabledRequest.setHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES, "true");
try (Session _ = client.connect(disabledListener, disabledRequest).get(5, TimeUnit.SECONDS)) {
assertThrows(TimeoutException.class, () -> disabledListener.queueEmptyFuture().get(10, TimeUnit.MILLISECONDS));
assertTrue(disabledListener.getReceivedEnvelopes().isEmpty());
assertFalse(disabledListener.queueEmptyFuture().isDone());
}
final TestWebsocketListener enabledListener = new TestWebsocketListener();
try (Session ignored = client.connect(enabledListener, uri).get(5, TimeUnit.SECONDS)) {
enabledListener.queueEmptyFuture().get(5, TimeUnit.SECONDS);
assertEquals(1, enabledListener.getReceivedEnvelopes().size());
assertEquals(
UUIDUtil.toByteString(messageGuid),
enabledListener.getReceivedEnvelopes().getFirst().getServerGuid());
}
}
@Test
void allDisconnected() throws Exception {
final MessageStream messageStream = mock(MessageStream.class);
when(messageStream.getMessages())
.thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.never()));
when(messagesManager.getMessages(ACCOUNT_UUID, device)).thenReturn(messageStream);
final URI uri = URI.create("ws://127.0.0.1:" + DROPWIZARD_APP_EXTENSION.getLocalPort() + "/websocket");
final TestWebsocketListener disabledListener = new TestWebsocketListener();
final TestWebsocketListener enabledListener = new TestWebsocketListener();
final ClientUpgradeRequest disabledRequest = new ClientUpgradeRequest(uri);
disabledRequest.setHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES, "true");
final ArgumentCaptor<DisconnectionRequestListener> captor =
ArgumentCaptor.forClass(DisconnectionRequestListener.class);
client.connect(disabledListener, disabledRequest).get(5, TimeUnit.SECONDS);
client.connect(enabledListener, uri).get(5, TimeUnit.SECONDS);
// Simulate a disconnection request
verify(disconnectionRequestManager, timeout(1000).times(2))
.addListener(eq(ACCOUNT_UUID), eq(DEVICE_ID), captor.capture());
captor.getAllValues().forEach(listener -> listener.handleDisconnectionRequest());
assertEquals(4401, disabledListener.closeFuture().join());
assertEquals(4401, enabledListener.closeFuture().join());
}
private static MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) {
final long timestamp = System.currentTimeMillis();
return MessageProtos.Envelope.newBuilder()
.setClientTimestamp(timestamp)
.setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(UUIDUtil.toByteString(messageGuid))
.setSourceServiceId(new AciServiceIdentifier(UUID.randomUUID()).toCompactByteString())
.setDestinationServiceId(new AciServiceIdentifier(ACCOUNT_UUID).toCompactByteString())
.build();
}
}

View File

@ -21,10 +21,13 @@ import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@ -58,6 +61,7 @@ class AuthenticatedConnectListenerTest {
disconnectionRequestManager,
() -> mock(AsnInfoProvider.class),
mock(ClientReleaseManager.class),
mock(MessageMetrics.class),
(_, _, _) -> authenticatedWebSocketConnection);
final Device device = mock(Device.class);
@ -83,10 +87,18 @@ class AuthenticatedConnectListenerTest {
authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext);
verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection);
final ArgumentCaptor<DisconnectionRequestListener> disconnectListener = ArgumentCaptor.forClass(DisconnectionRequestListener.class);
final ArgumentCaptor<WebSocketSessionContext.WebSocketEventListener> closeListener =
ArgumentCaptor.forClass(WebSocketSessionContext.WebSocketEventListener.class);
verify(disconnectionRequestManager).addListener(eq(ACCOUNT_IDENTIFIER), eq(DEVICE_ID), disconnectListener.capture());
// We expect one call from AuthenticatedConnectListener itself and one from OpenWebSocketCounter
verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(any());
verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(closeListener.capture());
verify(authenticatedWebSocketConnection).start();
// Verify that if we run the close listeners, the disconnection listener gets removed
closeListener.getAllValues()
.forEach(c -> c.onWebSocketClose(webSocketSessionContext, 1011, "test"));
verify(disconnectionRequestManager).removeListener(ACCOUNT_IDENTIFIER, DEVICE_ID, disconnectListener.getValue());
}
@Test
@ -118,12 +130,21 @@ class AuthenticatedConnectListenerTest {
authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext);
verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection);
// We expect one call from AuthenticatedConnectListener itself and one from OpenWebSocketCounter
verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(any());
final ArgumentCaptor<DisconnectionRequestListener> disconnectListener = ArgumentCaptor.forClass(DisconnectionRequestListener.class);
final ArgumentCaptor<WebSocketSessionContext.WebSocketEventListener> closeListener =
ArgumentCaptor.forClass(WebSocketSessionContext.WebSocketEventListener.class);
// We expect one call from OpenWebSocketCounter and one from AuthenticatedConnectListener itself
verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(closeListener.capture());
verify(disconnectionRequestManager).addListener(eq(ACCOUNT_IDENTIFIER), eq(DEVICE_ID), disconnectListener.capture());
verify(authenticatedWebSocketConnection).start();
verify(webSocketClient).close(eq(1011), anyString());
// Verify that if we run the close listeners, the disconnection listener gets removed
closeListener.getAllValues()
.forEach(c -> c.onWebSocketClose(webSocketSessionContext, 1011, "test"));
verify(disconnectionRequestManager).removeListener(ACCOUNT_IDENTIFIER, DEVICE_ID, disconnectListener.getValue());
}
@Test

View File

@ -97,7 +97,12 @@ public class WebSocketClient {
public boolean shouldDeliverStories() {
String value = session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES);
return WebsocketHeaders.parseReceiveStoriesHeader(value);
return WebsocketHeaders.parseBooleanHeader(value);
}
public boolean shouldDisableMessages() {
final String value = session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES);
return WebsocketHeaders.parseBooleanHeader(value);
}
private long generateRequestId() {

View File

@ -6,7 +6,13 @@ package org.whispersystems.websocket;
public class WebsocketHeaders {
public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories";
public static boolean parseReceiveStoriesHeader(String s) {
public final static String X_SIGNAL_DISABLE_MESSAGES = "X-Signal-Disable-Messages";
/// Parse a boolean header value
///
/// @param s the value of an HTTP header
/// @return true if the header is "true", otherwise false
public static boolean parseBooleanHeader(String s) {
return "true".equals(s);
}
}