SecureValueRecovery2/host/peer/server_test.go
2023-05-05 16:25:12 -06:00

356 lines
8.3 KiB
Go

// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
package peer
import (
"bufio"
"bytes"
"context"
"fmt"
"net"
"testing"
"github.com/signalapp/svr2/peerid"
pb "github.com/signalapp/svr2/proto"
)
type mockEnclave struct {
msgs []*pb.UntrustedMessage
}
func (t *mockEnclave) Send(p *pb.UntrustedMessage) (*pb.EnclaveMessage, error) {
t.msgs = append(t.msgs, p)
return nil, nil
}
type receiverFixture struct {
enclave *mockEnclave
pr *PeerServer
addr string
done chan error
ctx context.Context
cancel context.CancelFunc
}
func (f *receiverFixture) close() error {
f.cancel()
return <-f.done
}
func (*receiverFixture) seq(i uint64) sequenceNumber {
return sequenceNumber{epoch: 0, seq: i}
}
func startReceiver(t *testing.T) *receiverFixture {
e := &mockEnclave{}
ctx, cancel := context.WithCancel(context.Background())
pr := NewPeerServer(ctx, peer(0), e)
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
addr := ln.Addr().String()
done := make(chan error)
go func() { done <- pr.Listen(ln) }()
return &receiverFixture{
e,
pr,
addr,
done,
ctx,
cancel,
}
}
func (f *receiverFixture) dial(t *testing.T) net.Conn {
conn, err := net.Dial("tcp", f.addr)
if err != nil {
t.Fatal(err)
}
return conn
}
func (f *receiverFixture) handshake(t *testing.T, c net.Conn, initiatorId peerid.PeerID) sequenceNumber {
reader := bufio.NewReader(c)
if err := writeHello(c, initiatorId, f.pr.me); err != nil {
t.Error(err)
}
ack, err := readHelloAck(reader)
if err != nil {
t.Error(err)
}
return ack
}
func (*receiverFixture) trySendAck(c net.Conn, data []byte, seqno sequenceNumber) error {
if err := writeFramed(c, &pb.PeerConnectionMessage{
Inner: &pb.PeerConnectionMessage_Data{
Data: &pb.PeerConnectionData{
Seqno: seqno.proto(),
Msg: &pb.PeerMessage{
Inner: &pb.PeerMessage_Syn{Syn: data},
},
},
},
}); err != nil {
return err
}
ack, err := readAck(bufio.NewReader(c))
if err != nil {
return err
}
if ack != seqno {
return fmt.Errorf("readAck()=%v, wanted %v", ack, seqno)
}
return nil
}
func (f *receiverFixture) sendAck(t *testing.T, c net.Conn, data []byte, seqno sequenceNumber) {
if err := f.trySendAck(c, data, seqno); err != nil {
t.Error(err)
}
}
func TestSimpleAck(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
defer conn.Close()
ack := f.handshake(t, conn, peer(1))
if ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
f.sendAck(t, conn, []byte("data"), f.seq(1))
if len(f.enclave.msgs) != 1 {
t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 1)
}
msg := f.enclave.msgs[0]
if d, ok := msg.GetPeerMessage().Inner.(*pb.PeerMessage_Syn); !ok {
t.Errorf("not syn")
} else if !bytes.Equal(d.Syn, []byte("data")) {
t.Errorf("got data %v, want %v", d.Syn, []byte("data"))
}
}
func TestSequenceGap(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
defer conn.Close()
if ack := f.handshake(t, conn, peer(1)); ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
f.sendAck(t, conn, []byte("data"), f.seq(1))
if err := f.trySendAck(conn, []byte("data"), f.seq(3)); err == nil {
t.Fatal("expected error sending ack")
}
}
func TestEpochGap(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
defer conn.Close()
if ack := f.handshake(t, conn, peer(1)); ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
f.sendAck(t, conn, []byte("data"), sequenceNumber{0, 1})
f.sendAck(t, conn, []byte("data"), sequenceNumber{2, 0}) // skipping epoch is ok
// skipping a message in epoch 3 is not ok
if err := f.trySendAck(conn, []byte("data"), sequenceNumber{3, 1}); err == nil {
t.Fatal("expected error sending ack")
}
}
func TestEpochAck(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
defer conn.Close()
ack := f.handshake(t, conn, peer(1))
if ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
f.sendAck(t, conn, []byte("data"), sequenceNumber{0, 1})
f.sendAck(t, conn, []byte("data"), sequenceNumber{2, 0})
f.sendAck(t, conn, []byte("data"), sequenceNumber{2, 1})
f.sendAck(t, conn, []byte("data"), sequenceNumber{3, 0})
if len(f.enclave.msgs) != 4 {
t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 4)
}
}
func TestDisconnectAck(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
ack := f.handshake(t, conn, peer(1))
if ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
f.sendAck(t, conn, []byte("data"), f.seq(1))
f.sendAck(t, conn, []byte("data1"), f.seq(2))
f.sendAck(t, conn, []byte("data2"), f.seq(3))
if len(f.enclave.msgs) != 3 {
t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 3)
}
conn.Close()
conn = f.dial(t)
ack = f.handshake(t, conn, peer(1))
if ack != f.seq(3) {
t.Errorf("handshake ack = %v, want %v", ack, 3)
}
f.sendAck(t, conn, []byte("data4"), f.seq(4))
if len(f.enclave.msgs) != 4 {
t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 4)
}
}
func TestDisconnectEpochAck(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
ack := f.handshake(t, conn, peer(1))
if ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
f.sendAck(t, conn, []byte("data"), sequenceNumber{5, 0})
if len(f.enclave.msgs) != 1 {
t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 3)
}
conn.Close()
conn = f.dial(t)
if ack = f.handshake(t, conn, peer(1)); ack != (sequenceNumber{5, 0}) {
t.Errorf("handshake ack = %v, want %v", ack, "5:0")
}
f.sendAck(t, conn, []byte("data4"), sequenceNumber{5, 1})
if len(f.enclave.msgs) != 2 {
t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 2)
}
}
func TestResend(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
ack := f.handshake(t, conn, peer(1))
if ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
f.sendAck(t, conn, []byte("data"), f.seq(1))
f.sendAck(t, conn, []byte("data"), f.seq(1))
}
func TestReconnectingClient(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
ack := f.handshake(t, conn, peer(1))
if ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
// reconnect, server should disconnect old connection
conn2 := f.dial(t)
ack = f.handshake(t, conn2, peer(1))
if ack != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack, 0)
}
bs := []byte{0}
_, err := conn.Read(bs)
if err == nil {
t.Error("expected disconnect for original client")
}
}
func TestMultiplePeers(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn1 := f.dial(t)
ack1 := f.handshake(t, conn1, peer(1))
if ack1 != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack1, 0)
}
conn2 := f.dial(t)
ack2 := f.handshake(t, conn2, peer(2))
if ack2 != (sequenceNumber{}) {
t.Errorf("handshake ack = %v, want %v", ack2, 0)
}
f.sendAck(t, conn1, []byte("data1"), f.seq(1))
f.sendAck(t, conn2, []byte("data2"), f.seq(1))
if len(f.enclave.msgs) != 2 {
t.Errorf("got %v forwarded messages, want %v", len(f.enclave.msgs), 2)
}
got1 := f.enclave.msgs[0].GetPeerMessage().Inner.(*pb.PeerMessage_Syn).Syn
got2 := f.enclave.msgs[1].GetPeerMessage().Inner.(*pb.PeerMessage_Syn).Syn
if !bytes.Equal(got1, []byte("data1")) {
t.Errorf("got data %v, want %v", got1, []byte("data1"))
}
if !bytes.Equal(got2, []byte("data2")) {
t.Errorf("got data %v, want %v", got2, []byte("data2"))
}
}
func TestEndToEnd(t *testing.T) {
f := startReceiver(t)
defer f.close()
sender := createSender(withLookup(serverID, f.addr))
done := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
go func() { done <- sender.run(ctx) }()
msg := pb.PeerMessage{
PeerId: f.pr.me[:],
Inner: &pb.PeerMessage_Syn{Syn: []byte("data")},
}
*sender.tx.Load() <- &msg
cancel()
<-done
}
func TestRejectsNotMyPeerID(t *testing.T) {
f := startReceiver(t)
defer f.close()
conn := f.dial(t)
defer conn.Close()
reader := bufio.NewReader(conn)
if err := writeHello(conn, peer(1), peer(12 /*not me*/)); err != nil {
t.Error(err)
}
if _, err := readHelloAck(reader); err == nil {
t.Fatal("accepted connection when toPeerID did not match server")
}
}