b i g b a n g

This commit is contained in:
Moxie Marlinspike 2017-06-23 14:52:53 -07:00
commit d9b15e9519
13 changed files with 719 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
.idea
*.iml
target/

75
pom.xml Normal file
View File

@ -0,0 +1,75 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.whispersystems</groupId>
<artifactId>dispatch</artifactId>
<version>1.0</version>
<properties>
<dropwizard.version>1.1.0</dropwizard.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-core</artifactId>
<version>${dropwizard.version}</version>
</dependency>
<dependency>
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-testing</artifactId>
<version>${dropwizard.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>2.7.22</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.2</version>
<configuration>
<source>1.7</source>
<target>1.7</target>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>2.2.1</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>2.4</version>
<configuration>
<archive>
<manifest>
<addDefaultImplementationEntries>true</addDefaultImplementationEntries>
</manifest>
</archive>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,7 @@
package org.whispersystems.dispatch;
public interface DispatchChannel {
public void onDispatchMessage(String channel, byte[] message);
public void onDispatchSubscribed(String channel);
public void onDispatchUnsubscribed(String channel);
}

View File

@ -0,0 +1,172 @@
package org.whispersystems.dispatch;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.dispatch.redis.PubSubReply;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
public class DispatchManager extends Thread {
private final Logger logger = LoggerFactory.getLogger(DispatchManager.class);
private final Executor executor = Executors.newCachedThreadPool();
private final Map<String, DispatchChannel> subscriptions = new ConcurrentHashMap<>();
private final Optional<DispatchChannel> deadLetterChannel;
private final RedisPubSubConnectionFactory redisPubSubConnectionFactory;
private PubSubConnection pubSubConnection;
private volatile boolean running;
public DispatchManager(RedisPubSubConnectionFactory redisPubSubConnectionFactory,
Optional<DispatchChannel> deadLetterChannel)
{
this.redisPubSubConnectionFactory = redisPubSubConnectionFactory;
this.deadLetterChannel = deadLetterChannel;
}
@Override
public void start() {
this.pubSubConnection = redisPubSubConnectionFactory.connect();
this.running = true;
super.start();
}
public void shutdown() {
this.running = false;
this.pubSubConnection.close();
}
public synchronized void subscribe(String name, DispatchChannel dispatchChannel) {
Optional<DispatchChannel> previous = Optional.fromNullable(subscriptions.get(name));
subscriptions.put(name, dispatchChannel);
try {
pubSubConnection.subscribe(name);
} catch (IOException e) {
logger.warn("Subscription error", e);
}
if (previous.isPresent()) {
dispatchUnsubscription(name, previous.get());
}
}
public synchronized void unsubscribe(String name, DispatchChannel channel) {
Optional<DispatchChannel> subscription = Optional.fromNullable(subscriptions.get(name));
if (subscription.isPresent() && subscription.get() == channel) {
subscriptions.remove(name);
try {
pubSubConnection.unsubscribe(name);
} catch (IOException e) {
logger.warn("Unsubscribe error", e);
}
dispatchUnsubscription(name, subscription.get());
}
}
public boolean hasSubscription(String name) {
return subscriptions.containsKey(name);
}
@Override
public void run() {
while (running) {
try {
PubSubReply reply = pubSubConnection.read();
switch (reply.getType()) {
case UNSUBSCRIBE: break;
case SUBSCRIBE: dispatchSubscribe(reply); break;
case MESSAGE: dispatchMessage(reply); break;
default: throw new AssertionError("Unknown pubsub reply type! " + reply.getType());
}
} catch (IOException e) {
logger.warn("***** PubSub Connection Error *****", e);
if (running) {
this.pubSubConnection.close();
this.pubSubConnection = redisPubSubConnectionFactory.connect();
resubscribeAll();
}
}
}
logger.warn("DispatchManager Shutting Down...");
}
private void dispatchSubscribe(final PubSubReply reply) {
Optional<DispatchChannel> subscription = Optional.fromNullable(subscriptions.get(reply.getChannel()));
if (subscription.isPresent()) {
dispatchSubscription(reply.getChannel(), subscription.get());
} else {
logger.info("Received subscribe event for non-existing channel: " + reply.getChannel());
}
}
private void dispatchMessage(PubSubReply reply) {
Optional<DispatchChannel> subscription = Optional.fromNullable(subscriptions.get(reply.getChannel()));
if (subscription.isPresent()) {
dispatchMessage(reply.getChannel(), subscription.get(), reply.getContent().get());
} else if (deadLetterChannel.isPresent()) {
dispatchMessage(reply.getChannel(), deadLetterChannel.get(), reply.getContent().get());
} else {
logger.warn("Received message for non-existing channel, with no dead letter handler: " + reply.getChannel());
}
}
private void resubscribeAll() {
new Thread() {
@Override
public void run() {
synchronized (DispatchManager.this) {
try {
for (String name : subscriptions.keySet()) {
pubSubConnection.subscribe(name);
}
} catch (IOException e) {
logger.warn("***** RESUBSCRIPTION ERROR *****", e);
}
}
}
}.start();
}
private void dispatchMessage(final String name, final DispatchChannel channel, final byte[] message) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchMessage(name, message);
}
});
}
private void dispatchSubscription(final String name, final DispatchChannel channel) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchSubscribed(name);
}
});
}
private void dispatchUnsubscription(final String name, final DispatchChannel channel) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchUnsubscribed(name);
}
});
}
}

View File

@ -0,0 +1,64 @@
package org.whispersystems.dispatch.io;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
public class RedisInputStream {
private static final byte CR = 0x0D;
private static final byte LF = 0x0A;
private final InputStream inputStream;
public RedisInputStream(InputStream inputStream) {
this.inputStream = inputStream;
}
public String readLine() throws IOException {
ByteArrayOutputStream boas = new ByteArrayOutputStream();
boolean foundCr = false;
while (true) {
int character = inputStream.read();
if (character == -1) {
throw new IOException("Stream closed!");
}
boas.write(character);
if (foundCr && character == LF) break;
else if (character == CR) foundCr = true;
else if (foundCr) foundCr = false;
}
byte[] data = boas.toByteArray();
return new String(data, 0, data.length-2);
}
public byte[] readFully(int size) throws IOException {
byte[] result = new byte[size];
int offset = 0;
int remaining = result.length;
while (remaining > 0) {
int read = inputStream.read(result, offset, remaining);
if (read < 0) {
throw new IOException("Stream closed!");
}
offset += read;
remaining -= read;
}
return result;
}
public void close() throws IOException {
inputStream.close();
}
}

View File

@ -0,0 +1,9 @@
package org.whispersystems.dispatch.io;
import org.whispersystems.dispatch.redis.PubSubConnection;
public interface RedisPubSubConnectionFactory {
public PubSubConnection connect();
}

View File

@ -0,0 +1,119 @@
package org.whispersystems.dispatch.redis;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.io.RedisInputStream;
import org.whispersystems.dispatch.redis.protocol.ArrayReplyHeader;
import org.whispersystems.dispatch.redis.protocol.IntReply;
import org.whispersystems.dispatch.redis.protocol.StringReplyHeader;
import org.whispersystems.dispatch.util.Util;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.Socket;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
public class PubSubConnection {
private final Logger logger = LoggerFactory.getLogger(PubSubConnection.class);
private static final byte[] UNSUBSCRIBE_TYPE = {'u', 'n', 's', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' };
private static final byte[] SUBSCRIBE_TYPE = {'s', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' };
private static final byte[] MESSAGE_TYPE = {'m', 'e', 's', 's', 'a', 'g', 'e' };
private static final byte[] SUBSCRIBE_COMMAND = {'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' ' };
private static final byte[] UNSUBSCRIBE_COMMAND = {'U', 'N', 'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' '};
private static final byte[] CRLF = {'\r', '\n' };
private final OutputStream outputStream;
private final RedisInputStream inputStream;
private final Socket socket;
private final AtomicBoolean closed;
public PubSubConnection(Socket socket) throws IOException {
this.socket = socket;
this.outputStream = socket.getOutputStream();
this.inputStream = new RedisInputStream(new BufferedInputStream(socket.getInputStream()));
this.closed = new AtomicBoolean(false);
}
public void subscribe(String channelName) throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
byte[] command = Util.combine(SUBSCRIBE_COMMAND, channelName.getBytes(), CRLF);
outputStream.write(command);
}
public void unsubscribe(String channelName) throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
byte[] command = Util.combine(UNSUBSCRIBE_COMMAND, channelName.getBytes(), CRLF);
outputStream.write(command);
}
public PubSubReply read() throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
ArrayReplyHeader replyHeader = new ArrayReplyHeader(inputStream.readLine());
if (replyHeader.getElementCount() != 3) {
throw new IOException("Received array reply header with strange count: " + replyHeader.getElementCount());
}
StringReplyHeader replyTypeHeader = new StringReplyHeader(inputStream.readLine());
byte[] replyType = inputStream.readFully(replyTypeHeader.getStringLength());
inputStream.readLine();
if (Arrays.equals(SUBSCRIBE_TYPE, replyType)) return readSubscribeReply();
else if (Arrays.equals(UNSUBSCRIBE_TYPE, replyType)) return readUnsubscribeReply();
else if (Arrays.equals(MESSAGE_TYPE, replyType)) return readMessageReply();
else throw new IOException("Unknown reply type: " + new String(replyType));
}
public void close() {
try {
this.closed.set(true);
this.inputStream.close();
this.outputStream.close();
this.socket.close();
} catch (IOException e) {
logger.warn("Exception while closing", e);
}
}
private PubSubReply readMessageReply() throws IOException {
StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine());
byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength());
inputStream.readLine();
StringReplyHeader messageHeader = new StringReplyHeader(inputStream.readLine());
byte[] message = inputStream.readFully(messageHeader.getStringLength());
inputStream.readLine();
return new PubSubReply(PubSubReply.Type.MESSAGE, new String(channelName), Optional.of(message));
}
private PubSubReply readUnsubscribeReply() throws IOException {
String channelName = readSubscriptionReply();
return new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, channelName, Optional.<byte[]>absent());
}
private PubSubReply readSubscribeReply() throws IOException {
String channelName = readSubscriptionReply();
return new PubSubReply(PubSubReply.Type.SUBSCRIBE, channelName, Optional.<byte[]>absent());
}
private String readSubscriptionReply() throws IOException {
StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine());
byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength());
inputStream.readLine();
IntReply subscriptionCount = new IntReply(inputStream.readLine());
return new String(channelName);
}
}

View File

@ -0,0 +1,35 @@
package org.whispersystems.dispatch.redis;
import com.google.common.base.Optional;
public class PubSubReply {
public enum Type {
MESSAGE,
SUBSCRIBE,
UNSUBSCRIBE
}
private final Type type;
private final String channel;
private final Optional<byte[]> content;
public PubSubReply(Type type, String channel, Optional<byte[]> content) {
this.type = type;
this.channel = channel;
this.content = content;
}
public Type getType() {
return type;
}
public String getChannel() {
return channel;
}
public Optional<byte[]> getContent() {
return content;
}
}

View File

@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class ArrayReplyHeader {
private final int elementCount;
public ArrayReplyHeader(String header) throws IOException {
if (header == null || header.length() < 2 || header.charAt(0) != '*') {
throw new IOException("Invalid array reply header: " + header);
}
try {
this.elementCount = Integer.parseInt(header.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getElementCount() {
return elementCount;
}
}

View File

@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class IntReply {
private final int value;
public IntReply(String reply) throws IOException {
if (reply == null || reply.length() < 2 || reply.charAt(0) != ':') {
throw new IOException("Invalid int reply: " + reply);
}
try {
this.value = Integer.parseInt(reply.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getValue() {
return value;
}
}

View File

@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class StringReplyHeader {
private final int stringLength;
public StringReplyHeader(String header) throws IOException {
if (header == null || header.length() < 2 || header.charAt(0) != '$') {
throw new IOException("Invalid string reply header: " + header);
}
try {
this.stringLength = Integer.parseInt(header.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getStringLength() {
return stringLength;
}
}

View File

@ -0,0 +1,36 @@
package org.whispersystems.dispatch.util;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
public class Util {
public static byte[] combine(byte[]... elements) {
try {
int sum = 0;
for (byte[] element : elements) {
sum += element.length;
}
ByteArrayOutputStream baos = new ByteArrayOutputStream(sum);
for (byte[] element : elements) {
baos.write(element);
}
return baos.toByteArray();
} catch (IOException e) {
throw new AssertionError(e);
}
}
public static void sleep(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
}

View File

@ -0,0 +1,127 @@
package org.whispersystems.dispatch;
import com.google.common.base.Optional;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExternalResource;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.dispatch.redis.PubSubReply;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
public class DispatchManagerTest {
private PubSubConnection pubSubConnection;
private RedisPubSubConnectionFactory socketFactory;
private DispatchManager dispatchManager;
private PubSubReplyInputStream pubSubReplyInputStream;
@Rule
public ExternalResource resource = new ExternalResource() {
@Override
protected void before() throws Throwable {
pubSubConnection = mock(PubSubConnection.class );
socketFactory = mock(RedisPubSubConnectionFactory.class);
pubSubReplyInputStream = new PubSubReplyInputStream();
when(socketFactory.connect()).thenReturn(pubSubConnection);
when(pubSubConnection.read()).thenAnswer(new Answer<PubSubReply>() {
@Override
public PubSubReply answer(InvocationOnMock invocationOnMock) throws Throwable {
return pubSubReplyInputStream.read();
}
});
dispatchManager = new DispatchManager(socketFactory, Optional.<DispatchChannel>absent());
dispatchManager.start();
}
@Override
protected void after() {
}
};
@Test
public void testConnect() {
verify(socketFactory).connect();
}
@Test
public void testSubscribe() throws IOException {
DispatchChannel dispatchChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", dispatchChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
verify(dispatchChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
}
@Test
public void testSubscribeUnsubscribe() throws IOException {
DispatchChannel dispatchChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", dispatchChannel);
dispatchManager.unsubscribe("foo", dispatchChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, "foo", Optional.<byte[]>absent()));
verify(dispatchChannel, timeout(1000)).onDispatchUnsubscribed(eq("foo"));
}
@Test
public void testMessages() throws IOException {
DispatchChannel fooChannel = mock(DispatchChannel.class);
DispatchChannel barChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", fooChannel);
dispatchManager.subscribe("bar", barChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "bar", Optional.<byte[]>absent()));
verify(fooChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
verify(barChannel, timeout(1000)).onDispatchSubscribed(eq("bar"));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "foo", Optional.of("hello".getBytes())));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "bar", Optional.of("there".getBytes())));
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(fooChannel, timeout(1000)).onDispatchMessage(eq("foo"), captor.capture());
assertArrayEquals("hello".getBytes(), captor.getValue());
verify(barChannel, timeout(1000)).onDispatchMessage(eq("bar"), captor.capture());
assertArrayEquals("there".getBytes(), captor.getValue());
}
private static class PubSubReplyInputStream {
private final List<PubSubReply> pubSubReplyList = new LinkedList<>();
public synchronized PubSubReply read() {
try {
while (pubSubReplyList.isEmpty()) wait();
return pubSubReplyList.remove(0);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
public synchronized void write(PubSubReply pubSubReply) {
pubSubReplyList.add(pubSubReply);
notifyAll();
}
}
}