Add header to disable messages on authenticated websocket
This commit is contained in:
parent
8deb5a803a
commit
ac720595e6
@ -62,7 +62,7 @@ public class WebsocketClientSession implements Session.Listener.AutoDemanding {
|
||||
log.info("received response message {}", response.getStatus());
|
||||
final CompletableFuture<WebSocketResponseMessage> future = responseFutures.remove(response.getRequestId());
|
||||
if (future == null) {
|
||||
throw new IllegalArgumentException("Received response with no matching request: {}" + response.getRequestId());
|
||||
throw new IllegalArgumentException("Received response with no matching request: " + response.getRequestId());
|
||||
}
|
||||
future.complete(response);
|
||||
}
|
||||
|
||||
@ -35,6 +35,7 @@ public final class MessageMetrics {
|
||||
|
||||
public static String GRPC_CHANNEL = "grpc";
|
||||
public static String WEBSOCKET_CHANNEL = "websocket";
|
||||
public static String MESSAGELESS_WEBSOCKET_CHANNEL = "messageless-websocket";
|
||||
|
||||
@VisibleForTesting
|
||||
static final String MISMATCHED_ACCOUNT_ENVELOPE_UUID_COUNTER_NAME = name(MessageMetrics.class,
|
||||
|
||||
@ -38,6 +38,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
|
||||
private final AccountsManager accountsManager;
|
||||
private final DisconnectionRequestManager disconnectionRequestManager;
|
||||
private final WebSocketConnectionBuilder webSocketConnectionBuilder;
|
||||
private final MessageMetrics messageMetrics;
|
||||
|
||||
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
|
||||
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
|
||||
@ -66,6 +67,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
|
||||
disconnectionRequestManager,
|
||||
asnInfoProviderSupplier,
|
||||
clientReleaseManager,
|
||||
messageMetrics,
|
||||
(account, device, client) -> new WebSocketConnection(receiptSender,
|
||||
messagesManager,
|
||||
messageMetrics,
|
||||
@ -86,11 +88,13 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
|
||||
final DisconnectionRequestManager disconnectionRequestManager,
|
||||
final Supplier<AsnInfoProvider> asnInfoProviderSupplier,
|
||||
final ClientReleaseManager clientReleaseManager,
|
||||
final MessageMetrics messageMetrics,
|
||||
final WebSocketConnectionBuilder webSocketConnectionBuilder) {
|
||||
|
||||
this.accountsManager = accountsManager;
|
||||
this.disconnectionRequestManager = disconnectionRequestManager;
|
||||
this.webSocketConnectionBuilder = webSocketConnectionBuilder;
|
||||
this.messageMetrics = messageMetrics;
|
||||
|
||||
this.openAuthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-authenticated", asnInfoProviderSupplier, clientReleaseManager);
|
||||
this.openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-unauthenticated", asnInfoProviderSupplier, clientReleaseManager);
|
||||
@ -102,8 +106,10 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
|
||||
final boolean authenticated = (context.getAuthenticated() != null);
|
||||
|
||||
(authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter).countOpenWebSocket(context);
|
||||
if (!authenticated) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (authenticated) {
|
||||
final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class);
|
||||
|
||||
final Optional<Account> maybeAuthenticatedAccount =
|
||||
@ -119,29 +125,33 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
|
||||
return;
|
||||
}
|
||||
|
||||
final WebSocketConnection connection =
|
||||
webSocketConnectionBuilder.buildWebSocketConnection(maybeAuthenticatedAccount.get(),
|
||||
maybeAuthenticatedDevice.get(),
|
||||
context.getClient());
|
||||
final Account account = maybeAuthenticatedAccount.get();
|
||||
final Device device = maybeAuthenticatedDevice.get();
|
||||
|
||||
disconnectionRequestManager.addListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI),
|
||||
maybeAuthenticatedDevice.get().getId(),
|
||||
connection);
|
||||
final boolean disableMessages = context.getClient().shouldDisableMessages();
|
||||
|
||||
final Optional<WebSocketConnection> maybeWebSocketConnection = disableMessages
|
||||
? Optional.empty()
|
||||
: Optional.of(webSocketConnectionBuilder.buildWebSocketConnection(account, device, context.getClient()));
|
||||
|
||||
final WebSocketDisconnectionRequestListener disconnectionListener =
|
||||
new WebSocketDisconnectionRequestListener(messageMetrics, context.getClient(), disableMessages);
|
||||
|
||||
disconnectionRequestManager
|
||||
.addListener(account.getIdentifier(IdentityType.ACI), device.getId(), disconnectionListener);
|
||||
|
||||
context.addWebsocketClosedListener((_, _, _) -> {
|
||||
disconnectionRequestManager.removeListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI),
|
||||
maybeAuthenticatedDevice.get().getId(),
|
||||
connection);
|
||||
|
||||
connection.stop();
|
||||
disconnectionRequestManager
|
||||
.removeListener(account.getIdentifier(IdentityType.ACI), device.getId(), disconnectionListener);
|
||||
maybeWebSocketConnection.ifPresent(WebSocketConnection::stop);
|
||||
});
|
||||
|
||||
try {
|
||||
connection.start();
|
||||
maybeWebSocketConnection.ifPresent(WebSocketConnection::start);
|
||||
} catch (final Exception e) {
|
||||
log.warn("Failed to initialize websocket", e);
|
||||
context.getClient().close(1011, "Unexpected error initializing connection");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -15,7 +15,6 @@ import io.micrometer.core.instrument.Tags;
|
||||
import io.micrometer.core.instrument.Timer;
|
||||
import java.time.Duration;
|
||||
import java.util.Collections;
|
||||
import java.util.HexFormat;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
@ -29,7 +28,6 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.eclipse.jetty.util.StaticException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
@ -62,7 +60,7 @@ import reactor.core.observability.micrometer.Micrometer;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
|
||||
public class WebSocketConnection implements DisconnectionRequestListener {
|
||||
public class WebSocketConnection {
|
||||
|
||||
private static final Counter sendFailuresCounter = Metrics.counter(name(WebSocketConnection.class, "sendFailures"));
|
||||
|
||||
@ -333,10 +331,4 @@ public class WebSocketConnection implements DisconnectionRequestListener {
|
||||
throwable instanceof org.eclipse.jetty.io.EofException ||
|
||||
(throwable instanceof StaticException staticException && "Closed".equals(staticException.getMessage()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleDisconnectionRequest() {
|
||||
messageMetrics.measureMessageStreamDisplaced(MessageMetrics.WEBSOCKET_CHANNEL, userAgent, false);
|
||||
client.close(4401, "Reauthentication required");
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright 2026 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.websocket;
|
||||
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
|
||||
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
import org.whispersystems.websocket.WebSocketClient;
|
||||
|
||||
class WebSocketDisconnectionRequestListener implements DisconnectionRequestListener {
|
||||
|
||||
private final MessageMetrics messageMetrics;
|
||||
private final String metricChannel;
|
||||
private final WebSocketClient client;
|
||||
|
||||
WebSocketDisconnectionRequestListener(final MessageMetrics messageMetrics, final WebSocketClient client, final boolean disableMessages) {
|
||||
this.messageMetrics = messageMetrics;
|
||||
this.client = client;
|
||||
this.metricChannel = disableMessages
|
||||
? MessageMetrics.MESSAGELESS_WEBSOCKET_CHANNEL
|
||||
: MessageMetrics.WEBSOCKET_CHANNEL;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleDisconnectionRequest() {
|
||||
final UserAgent userAgent = UserAgentUtil.maybeParseUserAgentString(client.getUserAgent());
|
||||
messageMetrics.measureMessageStreamDisplaced(metricChannel, userAgent, false);
|
||||
client.close(4401, "Reauthentication required");
|
||||
}
|
||||
}
|
||||
@ -4,25 +4,37 @@
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.tests.util;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import org.eclipse.jetty.websocket.api.Callback;
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.websocket.messages.WebSocketMessage;
|
||||
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
|
||||
import org.whispersystems.websocket.messages.WebSocketRequestMessage;
|
||||
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
||||
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
|
||||
|
||||
public class TestWebsocketListener implements Session.Listener.AutoDemanding {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(TestWebsocketListener.class);
|
||||
private final AtomicLong requestId = new AtomicLong();
|
||||
private final CompletableFuture<Session> started = new CompletableFuture<>();
|
||||
protected final CompletableFuture<Session> started = new CompletableFuture<>();
|
||||
private final CompletableFuture<Void> queueEmpty = new CompletableFuture<>();
|
||||
private final CompletableFuture<Integer> closed = new CompletableFuture<>();
|
||||
private final List<MessageProtos.Envelope> receivedEnvelopes = new CopyOnWriteArrayList<>();
|
||||
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
|
||||
protected final WebSocketMessageFactory messageFactory;
|
||||
|
||||
@ -46,6 +58,10 @@ public class TestWebsocketListener implements Session.Listener.AutoDemanding {
|
||||
return closed;
|
||||
}
|
||||
|
||||
public CompletableFuture<Void> queueEmptyFuture() {
|
||||
return queueEmpty;
|
||||
}
|
||||
|
||||
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
|
||||
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
|
||||
}
|
||||
@ -66,18 +82,55 @@ public class TestWebsocketListener implements Session.Listener.AutoDemanding {
|
||||
});
|
||||
}
|
||||
|
||||
public List<MessageProtos.Envelope> getReceivedEnvelopes() {
|
||||
return receivedEnvelopes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
|
||||
try {
|
||||
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
|
||||
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
|
||||
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
|
||||
.complete(webSocketMessage.getResponseMessage());
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
|
||||
final WebSocketMessage message = messageFactory.parseMessage(payload);
|
||||
switch (message.getType()) {
|
||||
case REQUEST_MESSAGE -> {
|
||||
log.info("received request message {} {}", message.getRequestMessage().getVerb(), message.getRequestMessage().getPath());
|
||||
switch (message.getRequestMessage().getPath()) {
|
||||
case "/api/v1/message" -> acknowledge(message.getRequestMessage());
|
||||
case "/api/v1/queue/empty" -> queueEmpty.complete(null);
|
||||
default -> throw new IllegalStateException("Unexpected path: " + message.getRequestMessage().getPath());
|
||||
}
|
||||
}
|
||||
case RESPONSE_MESSAGE -> {
|
||||
final WebSocketResponseMessage response = message.getResponseMessage();
|
||||
log.info("received response message {}", response.getStatus());
|
||||
final CompletableFuture<WebSocketResponseMessage> future = responseFutures.remove(response.getRequestId());
|
||||
if (future == null) {
|
||||
throw new IllegalArgumentException("Received response with no matching request: " + response.getRequestId());
|
||||
}
|
||||
future.complete(response);
|
||||
}
|
||||
default -> throw new IllegalStateException("Unexpected message type: " + message.getType());
|
||||
}
|
||||
callback.succeed();
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
log.warn("Failed to process message received over the websocket", e);
|
||||
callback.fail(e);
|
||||
started.join().close(1006, e.getMessage(), Callback.NOOP);
|
||||
}
|
||||
}
|
||||
|
||||
private void acknowledge(WebSocketRequestMessage message) {
|
||||
final byte[] envelopeBytes = message.getBody()
|
||||
.orElseThrow(() -> new UncheckedIOException(new IOException("Messages should have a body")));
|
||||
try {
|
||||
final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(envelopeBytes);
|
||||
receivedEnvelopes.add(envelope);
|
||||
final Session session = started.join();
|
||||
final WebSocketMessage response = messageFactory.createResponse(message.getRequestId(), 200, "",
|
||||
Collections.emptyList(), Optional.empty());
|
||||
session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP);
|
||||
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new UncheckedIOException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,230 @@
|
||||
/*
|
||||
* Copyright 2026 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.websocket;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.timeout;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
|
||||
|
||||
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.dropwizard.core.Application;
|
||||
import io.dropwizard.core.Configuration;
|
||||
import io.dropwizard.core.setup.Environment;
|
||||
import io.dropwizard.testing.junit5.DropwizardAppExtension;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import java.net.URI;
|
||||
import java.time.Instant;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
|
||||
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
|
||||
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
|
||||
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
|
||||
import org.whispersystems.textsecuregcm.push.ReceiptSender;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.MessageStream;
|
||||
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
|
||||
import org.whispersystems.websocket.WebsocketHeaders;
|
||||
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
|
||||
import org.whispersystems.websocket.setup.WebSocketEnvironment;
|
||||
import reactor.adapter.JdkFlowAdapter;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
@Timeout(value = 15, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
|
||||
class AuthenticatedConnectListenerIntegrationTest {
|
||||
|
||||
@RegisterExtension
|
||||
static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
|
||||
new DropwizardAppExtension<>(TestApplication.class);
|
||||
|
||||
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
|
||||
private static final byte DEVICE_ID = Device.PRIMARY_ID;
|
||||
private static final String E164 = PhoneNumberUtil.getInstance()
|
||||
.format(PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164);
|
||||
|
||||
|
||||
private static final DisconnectionRequestManager disconnectionRequestManager = mock(DisconnectionRequestManager.class);
|
||||
private static final MessagesManager messagesManager = mock(MessagesManager.class);
|
||||
private static final AccountsManager accountsManager = mock(AccountsManager.class);
|
||||
private static final Account account = mock(Account.class);
|
||||
private static final Device device = mock(Device.class);
|
||||
|
||||
private WebSocketClient client;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
reset(messagesManager, disconnectionRequestManager, accountsManager, account, device);
|
||||
when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false));
|
||||
when(account.getNumber()).thenReturn(E164);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_UUID);
|
||||
when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(DEVICE_ID);
|
||||
when(accountsManager.getByAccountIdentifier(ACCOUNT_UUID)).thenReturn(Optional.of(account));
|
||||
|
||||
client = new WebSocketClient();
|
||||
client.start();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
client.stop();
|
||||
}
|
||||
|
||||
public static class TestApplication extends Application<Configuration> {
|
||||
|
||||
@Override
|
||||
public void run(final Configuration configuration, final Environment environment) {
|
||||
final AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(
|
||||
accountsManager,
|
||||
mock(ReceiptSender.class),
|
||||
messagesManager,
|
||||
new MessageMetrics(),
|
||||
mock(PushNotificationManager.class),
|
||||
mock(PushNotificationScheduler.class),
|
||||
disconnectionRequestManager,
|
||||
Schedulers.boundedElastic(),
|
||||
() -> mock(AsnInfoProvider.class),
|
||||
mock(ClientReleaseManager.class),
|
||||
mock(MessageDeliveryLoopMonitor.class),
|
||||
mock(ExperimentEnrollmentManager.class));
|
||||
|
||||
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
|
||||
new WebSocketEnvironment<>(environment, new WebSocketConfiguration());
|
||||
|
||||
webSocketEnvironment.setAuthenticator(_ ->
|
||||
Optional.of(new AuthenticatedDevice(ACCOUNT_UUID, DEVICE_ID, Instant.now())));
|
||||
webSocketEnvironment.setConnectListener(connectListener);
|
||||
|
||||
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
|
||||
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
|
||||
REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||
|
||||
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(),
|
||||
(servletContext, container) -> {
|
||||
container.addMapping("/websocket", webSocketServlet);
|
||||
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void messagesDeliveredOnlyWhenHeaderAbsent() throws Exception {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
||||
|
||||
final MessageStream messageStream = mock(MessageStream.class);
|
||||
when(messageStream.getMessages()).thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just(
|
||||
new MessageStreamEntry.Envelope(envelope),
|
||||
new MessageStreamEntry.QueueEmpty())));
|
||||
when(messageStream.acknowledgeMessage(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
|
||||
when(messagesManager.getMessages(ACCOUNT_UUID, device)).thenReturn(messageStream);
|
||||
final URI uri = URI.create("ws://127.0.0.1:" + DROPWIZARD_APP_EXTENSION.getLocalPort() + "/websocket");
|
||||
|
||||
final TestWebsocketListener disabledListener = new TestWebsocketListener();
|
||||
final ClientUpgradeRequest disabledRequest = new ClientUpgradeRequest(uri);
|
||||
disabledRequest.setHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES, "true");
|
||||
try (Session _ = client.connect(disabledListener, disabledRequest).get(5, TimeUnit.SECONDS)) {
|
||||
assertThrows(TimeoutException.class, () -> disabledListener.queueEmptyFuture().get(10, TimeUnit.MILLISECONDS));
|
||||
assertTrue(disabledListener.getReceivedEnvelopes().isEmpty());
|
||||
assertFalse(disabledListener.queueEmptyFuture().isDone());
|
||||
}
|
||||
|
||||
final TestWebsocketListener enabledListener = new TestWebsocketListener();
|
||||
try (Session ignored = client.connect(enabledListener, uri).get(5, TimeUnit.SECONDS)) {
|
||||
enabledListener.queueEmptyFuture().get(5, TimeUnit.SECONDS);
|
||||
assertEquals(1, enabledListener.getReceivedEnvelopes().size());
|
||||
assertEquals(
|
||||
UUIDUtil.toByteString(messageGuid),
|
||||
enabledListener.getReceivedEnvelopes().getFirst().getServerGuid());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void allDisconnected() throws Exception {
|
||||
final MessageStream messageStream = mock(MessageStream.class);
|
||||
when(messageStream.getMessages())
|
||||
.thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.never()));
|
||||
when(messagesManager.getMessages(ACCOUNT_UUID, device)).thenReturn(messageStream);
|
||||
|
||||
final URI uri = URI.create("ws://127.0.0.1:" + DROPWIZARD_APP_EXTENSION.getLocalPort() + "/websocket");
|
||||
|
||||
final TestWebsocketListener disabledListener = new TestWebsocketListener();
|
||||
final TestWebsocketListener enabledListener = new TestWebsocketListener();
|
||||
final ClientUpgradeRequest disabledRequest = new ClientUpgradeRequest(uri);
|
||||
disabledRequest.setHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES, "true");
|
||||
|
||||
final ArgumentCaptor<DisconnectionRequestListener> captor =
|
||||
ArgumentCaptor.forClass(DisconnectionRequestListener.class);
|
||||
client.connect(disabledListener, disabledRequest).get(5, TimeUnit.SECONDS);
|
||||
client.connect(enabledListener, uri).get(5, TimeUnit.SECONDS);
|
||||
|
||||
// Simulate a disconnection request
|
||||
verify(disconnectionRequestManager, timeout(1000).times(2))
|
||||
.addListener(eq(ACCOUNT_UUID), eq(DEVICE_ID), captor.capture());
|
||||
captor.getAllValues().forEach(listener -> listener.handleDisconnectionRequest());
|
||||
|
||||
assertEquals(4401, disabledListener.closeFuture().join());
|
||||
assertEquals(4401, enabledListener.closeFuture().join());
|
||||
}
|
||||
|
||||
private static MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) {
|
||||
final long timestamp = System.currentTimeMillis();
|
||||
return MessageProtos.Envelope.newBuilder()
|
||||
.setClientTimestamp(timestamp)
|
||||
.setServerTimestamp(timestamp)
|
||||
.setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256)))
|
||||
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
|
||||
.setServerGuid(UUIDUtil.toByteString(messageGuid))
|
||||
.setSourceServiceId(new AciServiceIdentifier(UUID.randomUUID()).toCompactByteString())
|
||||
.setDestinationServiceId(new AciServiceIdentifier(ACCOUNT_UUID).toCompactByteString())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@ -21,10 +21,13 @@ import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||
@ -58,6 +61,7 @@ class AuthenticatedConnectListenerTest {
|
||||
disconnectionRequestManager,
|
||||
() -> mock(AsnInfoProvider.class),
|
||||
mock(ClientReleaseManager.class),
|
||||
mock(MessageMetrics.class),
|
||||
(_, _, _) -> authenticatedWebSocketConnection);
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
@ -83,10 +87,18 @@ class AuthenticatedConnectListenerTest {
|
||||
|
||||
authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext);
|
||||
|
||||
verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection);
|
||||
final ArgumentCaptor<DisconnectionRequestListener> disconnectListener = ArgumentCaptor.forClass(DisconnectionRequestListener.class);
|
||||
final ArgumentCaptor<WebSocketSessionContext.WebSocketEventListener> closeListener =
|
||||
ArgumentCaptor.forClass(WebSocketSessionContext.WebSocketEventListener.class);
|
||||
verify(disconnectionRequestManager).addListener(eq(ACCOUNT_IDENTIFIER), eq(DEVICE_ID), disconnectListener.capture());
|
||||
// We expect one call from AuthenticatedConnectListener itself and one from OpenWebSocketCounter
|
||||
verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(any());
|
||||
verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(closeListener.capture());
|
||||
verify(authenticatedWebSocketConnection).start();
|
||||
|
||||
// Verify that if we run the close listeners, the disconnection listener gets removed
|
||||
closeListener.getAllValues()
|
||||
.forEach(c -> c.onWebSocketClose(webSocketSessionContext, 1011, "test"));
|
||||
verify(disconnectionRequestManager).removeListener(ACCOUNT_IDENTIFIER, DEVICE_ID, disconnectListener.getValue());
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -118,12 +130,21 @@ class AuthenticatedConnectListenerTest {
|
||||
|
||||
authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext);
|
||||
|
||||
verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection);
|
||||
// We expect one call from AuthenticatedConnectListener itself and one from OpenWebSocketCounter
|
||||
verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(any());
|
||||
final ArgumentCaptor<DisconnectionRequestListener> disconnectListener = ArgumentCaptor.forClass(DisconnectionRequestListener.class);
|
||||
final ArgumentCaptor<WebSocketSessionContext.WebSocketEventListener> closeListener =
|
||||
ArgumentCaptor.forClass(WebSocketSessionContext.WebSocketEventListener.class);
|
||||
|
||||
// We expect one call from OpenWebSocketCounter and one from AuthenticatedConnectListener itself
|
||||
verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(closeListener.capture());
|
||||
verify(disconnectionRequestManager).addListener(eq(ACCOUNT_IDENTIFIER), eq(DEVICE_ID), disconnectListener.capture());
|
||||
verify(authenticatedWebSocketConnection).start();
|
||||
|
||||
verify(webSocketClient).close(eq(1011), anyString());
|
||||
|
||||
// Verify that if we run the close listeners, the disconnection listener gets removed
|
||||
closeListener.getAllValues()
|
||||
.forEach(c -> c.onWebSocketClose(webSocketSessionContext, 1011, "test"));
|
||||
verify(disconnectionRequestManager).removeListener(ACCOUNT_IDENTIFIER, DEVICE_ID, disconnectListener.getValue());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@ -97,7 +97,12 @@ public class WebSocketClient {
|
||||
|
||||
public boolean shouldDeliverStories() {
|
||||
String value = session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES);
|
||||
return WebsocketHeaders.parseReceiveStoriesHeader(value);
|
||||
return WebsocketHeaders.parseBooleanHeader(value);
|
||||
}
|
||||
|
||||
public boolean shouldDisableMessages() {
|
||||
final String value = session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_DISABLE_MESSAGES);
|
||||
return WebsocketHeaders.parseBooleanHeader(value);
|
||||
}
|
||||
|
||||
private long generateRequestId() {
|
||||
|
||||
@ -6,7 +6,13 @@ package org.whispersystems.websocket;
|
||||
public class WebsocketHeaders {
|
||||
public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories";
|
||||
|
||||
public static boolean parseReceiveStoriesHeader(String s) {
|
||||
public final static String X_SIGNAL_DISABLE_MESSAGES = "X-Signal-Disable-Messages";
|
||||
|
||||
/// Parse a boolean header value
|
||||
///
|
||||
/// @param s the value of an HTTP header
|
||||
/// @return true if the header is "true", otherwise false
|
||||
public static boolean parseBooleanHeader(String s) {
|
||||
return "true".equals(s);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user