diff --git a/tree/transparency/test/test.go b/tree/transparency/test/test.go index 97d4512..d1da262 100644 --- a/tree/transparency/test/test.go +++ b/tree/transparency/test/test.go @@ -8,6 +8,7 @@ package test import ( "crypto/ed25519" "crypto/rand" + "errors" mrand "math/rand" "slices" "testing" @@ -71,7 +72,33 @@ func Last(store transparency.ClientStorage) *pb.Consistency { return &pb.Consistency{Last: &head.TreeSize} } +type commitFailStore struct { + db.TransparencyStore + commitCount int + failAfter int +} + +func (s *commitFailStore) Commit(head *db.TransparencyTreeHead, auditors map[string]*db.AuditorTreeHead) error { + s.commitCount++ + if s.commitCount > s.failAfter { + return errors.New("commit failed") + } + return s.TransparencyStore.Commit(head, auditors) +} + +func newCommitFailStore(inner db.TransparencyStore, failAfter int) *commitFailStore { + return &commitFailStore{ + TransparencyStore: inner, + commitCount: 0, + failAfter: failAfter, + } +} + func NewTree(t testing.TB, deploymentMode transparency.DeploymentMode) (*transparency.Tree, *MemoryClientStorage, *transparency.PrivateConfig, []ed25519.PrivateKey) { + return NewTreeWithStore(t, deploymentMode, db.NewMemoryTransparencyStore()) +} + +func NewTreeWithStore(t testing.TB, deploymentMode transparency.DeploymentMode, store db.TransparencyStore) (*transparency.Tree, *MemoryClientStorage, *transparency.PrivateConfig, []ed25519.PrivateKey) { _, sigKey, err := ed25519.GenerateKey(nil) if err != nil { t.Fatal(err) @@ -108,17 +135,17 @@ func NewTree(t testing.TB, deploymentMode transparency.DeploymentMode) (*transpa auditorPrivateKeys = []ed25519.PrivateKey{auditor1PrivateKey, auditor2PrivateKey} } - tree, err := transparency.NewTree(config, db.NewMemoryTransparencyStore()) + tree, err := transparency.NewTree(config, store) if err != nil { t.Fatal(err) } - store := &MemoryClientStorage{ + clientStore := &MemoryClientStorage{ config: config.Public(), data: make(map[string]*transparency.MonitoringData), } - return tree, store, config, auditorPrivateKeys + return tree, clientStore, config, auditorPrivateKeys } func RandomTree(tree *transparency.Tree, store transparency.ClientStorage, total int, keys, repeats []int) ([][]byte, error) { diff --git a/tree/transparency/test/transparency_test.go b/tree/transparency/test/transparency_test.go index 0a6b7dd..3fb564e 100644 --- a/tree/transparency/test/transparency_test.go +++ b/tree/transparency/test/transparency_test.go @@ -10,11 +10,13 @@ import ( "errors" mrand "math/rand" "testing" + "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/signalapp/keytransparency/cmd/shared" + "github.com/signalapp/keytransparency/db" "github.com/signalapp/keytransparency/tree/transparency" "github.com/signalapp/keytransparency/tree/transparency/math" "github.com/signalapp/keytransparency/tree/transparency/pb" @@ -110,6 +112,137 @@ func TestTreeWithAuditorHeads(t *testing.T) { } } +func TestTree_SetAuditorHead_FirstCommitFails(t *testing.T) { + tree, store, privateConfig, auditorPrivateKeys := NewTreeWithStore(t, + transparency.ThirdPartyAuditing, + // We want the commit to fail after 1 update + newCommitFailStore(db.NewMemoryTransparencyStore(), 1)) + + // Add a key to the new tree + key, value := random(), random() + updateReq := &pb.UpdateRequest{ + SearchKey: key, + Value: value, + Consistency: Last(store), + } + + _, err := tree.UpdateSimple(updateReq) + if err != nil { + t.Fatal(err) + } + + // Set auditor's first tree head + root, err := tree.GetLogTree().GetRoot(1) + if err != nil { + t.Fatal(err) + } + auditorHead, _, err := transparency.SignNewAuditorHead(auditorPrivateKeys[0], privateConfig.Public(), 1, root, exampleAuditorName1) + if err != nil { + t.Fatal(err) + } + err = tree.SetAuditorHead(&pb.AuditorTreeHead{ + TreeSize: auditorHead.TreeSize, + Timestamp: auditorHead.Timestamp, + Signature: auditorHead.Signature, + }, exampleAuditorName1) + if err == nil { + t.Fatal(err) + } + + // Search for the first key and check that no auditor tree head exists + searchReq := &pb.TreeSearchRequest{ + SearchKey: key, + Consistency: Last(store), + } + searchRes, err := tree.Search(searchReq) + if err != nil { + t.Fatal(err) + } + + if len(searchRes.GetTreeHead().GetFullAuditorTreeHeads()) > 0 { + t.Fatal("expected no auditor tree heads") + } +} + +func TestTree_SetAuditorHead_SecondCommitFails(t *testing.T) { + tree, store, privateConfig, auditorPrivateKeys := NewTreeWithStore(t, + transparency.ThirdPartyAuditing, + // We want the commit to fail after two updates (1 simple update, 1 set auditor head) + newCommitFailStore(db.NewMemoryTransparencyStore(), 2)) + + // Add a key to the new tree + key1, value1 := random(), random() + updateReq1 := &pb.UpdateRequest{ + SearchKey: key1, + Value: value1, + Consistency: Last(store), + } + + _, err := tree.UpdateSimple(updateReq1) + if err != nil { + t.Fatal(err) + } + + // Set auditor's first tree head + root, err := tree.GetLogTree().GetRoot(1) + if err != nil { + t.Fatal(err) + } + auditorHead1, _, err := transparency.SignNewAuditorHead(auditorPrivateKeys[0], privateConfig.Public(), 1, root, exampleAuditorName1) + if err != nil { + t.Fatal(err) + } + err = tree.SetAuditorHead(&pb.AuditorTreeHead{ + TreeSize: auditorHead1.TreeSize, + Timestamp: auditorHead1.Timestamp, + Signature: auditorHead1.Signature, + }, exampleAuditorName1) + if err != nil { + t.Fatal(err) + } + + // Set auditor's second tree head. It contains a unix timestamp in milliseconds + // so the signature will be different from the first one if we sleep for a millisecond. + time.Sleep(1 * time.Millisecond) + root, err = tree.GetLogTree().GetRoot(1) + if err != nil { + t.Fatal(err) + } + auditorHead2, _, err := transparency.SignNewAuditorHead(auditorPrivateKeys[0], privateConfig.Public(), 1, root, exampleAuditorName1) + if err != nil { + t.Fatal(err) + } + + if bytes.Equal(auditorHead1.Signature, auditorHead2.Signature) { + t.Fatal("expected auditor heads to be different") + } + + err = tree.SetAuditorHead(&pb.AuditorTreeHead{ + TreeSize: auditorHead2.TreeSize, + Timestamp: auditorHead2.Timestamp, + Signature: auditorHead2.Signature, + }, exampleAuditorName1) + if err == nil { + t.Fatal(err) + } + + // Search for the first key and check that the stored auditor tree head is the first one + searchReq := &pb.TreeSearchRequest{ + SearchKey: key1, + Consistency: Last(store), + } + searchRes, err := tree.Search(searchReq) + if err != nil { + t.Fatal(err) + } + + auditorTreeSize := searchRes.GetTreeHead().GetFullAuditorTreeHeads()[0].GetTreeHead().GetTreeSize() + signature := searchRes.GetTreeHead().GetFullAuditorTreeHeads()[0].GetTreeHead().GetSignature() + if auditorTreeSize != 1 || !bytes.Equal(signature, auditorHead1.Signature) { + t.Fatal("wrong auditor head") + } +} + func TestTree(t *testing.T) { tree, store, _, _ := NewTree(t, transparency.ContactMonitoring) diff --git a/tree/transparency/transparency.go b/tree/transparency/transparency.go index 5ef2df6..d912153 100644 --- a/tree/transparency/transparency.go +++ b/tree/transparency/transparency.go @@ -1030,14 +1030,19 @@ func (t *Tree) SetAuditorHead(head *pb.AuditorTreeHead, auditorName string) erro Consistency: consistency, } - // Store in database and return. + // Save the old value in case the database commit fails + old, exists := t.auditors[auditorName] + t.auditors[auditorName] = auditor + if err := t.tx.Commit(t.latest, t.auditors); err != nil { + if exists { + t.auditors[auditorName] = old + } else { + delete(t.auditors, auditorName) + } return err } - // Update the auditor tree head for this auditor - t.auditors[auditorName] = auditor - return nil }