Checkpoint after every record is processed
Some checks failed
CI / test (push) Has been cancelled

This commit is contained in:
Katherine 2026-03-12 10:56:45 -04:00 committed by GitHub
parent 563f41039c
commit 5fb822474e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 199 deletions

View File

@ -128,7 +128,6 @@ type StreamConfig struct {
AciStreamName envstr `yaml:"aci-stream-name"`
E164StreamName envstr `yaml:"e164-stream-name"`
UsernameStreamName envstr `yaml:"username-stream-name"`
CheckpointSize uint `yaml:"checkpoint-size"`
NewStreams []string `yaml:"new-streams"`
@ -275,9 +274,6 @@ func Read(filename string) (*Config, error) {
if parsed.StreamConfig.UsernameStreamName == "" {
return nil, fmt.Errorf("field not provided: stream.username-stream-name")
}
if parsed.StreamConfig.CheckpointSize == 0 {
return nil, fmt.Errorf("stream.checkpoint-size cannot be 0")
}
}
if parsed.APIConfig.JitterPercent < 0 || parsed.APIConfig.JitterPercent > 100 {

View File

@ -216,7 +216,7 @@ func main() {
util.Log().Infof("%s stream start timestamp: %s", streamName, streamStartTimestamp.Format(time.RFC3339))
}
go func() {
s.run(ctx, streamName, streamStartTimestamp, updateHandler, updateFromStreamFunc, config.StreamConfig.CheckpointSize)
s.run(ctx, streamName, streamStartTimestamp, updateHandler, updateFromStreamFunc)
}()
}

View File

@ -13,8 +13,6 @@ import (
"fmt"
"math"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
@ -59,106 +57,6 @@ type kinesisLogger struct{}
// Log formats its arguments with fmt.Sprintln and writes the result at info
// level through util.Log(). The fixed "%s" format string keeps any '%'
// characters in the formatted message from being treated as formatting verbs.
func (kl kinesisLogger) Log(v ...any) { util.Log().Infof("%s", fmt.Sprintln(v...)) }
// shardMap implements locking around a map from shard id to shard state, used
// to coordinate many updater goroutines processing Kinesis records.
type shardMap struct {
// mutex is held by start and finish while they read or write done and shards.
mutex sync.Mutex
// done is set to true when the Kinesis stream is being shut down. This
// prevents any further updates from being processed.
done bool
// shards is a map from shard id to shard state. Entries are created on
// demand by start.
shards map[string]*shardState
}
// newShardMap returns an empty shardMap that is not marked done and whose
// shards map is already allocated, so start can insert into it directly.
func newShardMap() *shardMap {
return &shardMap{
mutex: sync.Mutex{},
done: false,
shards: make(map[string]*shardState),
}
}
// start starts a new update for the given shard id. It returns the shard state,
// creating it first if this shard id has not been seen before.
// If the stream is being shut down (finish has been called), it returns nil and
// records nothing.
func (sm *shardMap) start(id string) *shardState {
sm.mutex.Lock()
defer sm.mutex.Unlock()
if sm.done {
return nil
}
// Lazily create the state the first time a shard id is seen.
if _, ok := sm.shards[id]; !ok {
sm.shards[id] = &shardState{}
}
state := sm.shards[id]
// Register the in-flight update while the mutex is still held, so finish
// cannot mark the map done between the check above and this registration.
state.start()
return state
}
// finish stops new updates from being processed and waits for all existing
// updates to finish.
//
// The mutex is held for the entire wait, so a concurrent call to start blocks
// until finish returns and then observes done == true. Nothing in this type
// resets done afterwards.
func (sm *shardMap) finish() {
sm.mutex.Lock()
defer sm.mutex.Unlock()
sm.done = true
// Block until every update registered through start has called done.
for _, shard := range sm.shards {
shard.waitGroup.Wait()
}
}
// shardState is the update state for an individual Kinesis shard.
type shardState struct {
// sinceLast is the number of records processed since the last checkpoint.
// It is incremented by start and reset to zero by wait.
sinceLast int
// waitGroup tracks when all pending updates have been processed.
// start adds to it and done releases it.
waitGroup sync.WaitGroup
// searchKeyLocks is a map from a search key to chan struct{}. It prevents multiple updates
// from being processed for the same search key simultaneously.
// The key is the hex encoding of the search key; the channel is closed when
// the lock is released (see lockSearchKey).
searchKeyLocks sync.Map
// didFail is set to true if any updates failed to process.
// It is set by failed and read by wait.
didFail atomic.Bool
}
// start registers one more in-flight update: it increments the
// since-last-checkpoint counter and adds one to the wait group. done is the
// matching call that releases the wait group slot.
// NOTE(review): sinceLast is a plain int with no synchronization of its own;
// presumably only one goroutine per shard touches it — confirm against callers.
func (ss *shardState) start() {
ss.sinceLast++
ss.waitGroup.Add(1)
}
// wait waits for all pending updates to finish. It returns true if any of the
// updates failed to process.
//
// It also resets sinceLast to zero. It does not clear didFail, so once failed
// has been called every later wait on the same state also returns true.
func (ss *shardState) wait() bool {
ss.waitGroup.Wait()
ss.sinceLast = 0
return ss.didFail.Load()
}
// lockSearchKey blocks until it is able to get an exclusive lock on the given search key. No
// other goroutines are able to obtain a lock until `unlock` is called.
//
// The lock is represented by an entry in searchKeyLocks: whoever manages to
// store their channel under the key holds the lock, and closing that channel
// wakes everyone waiting on it.
func (ss *shardState) lockSearchKey(searchKey []byte) (unlock func()) {
// Hex-encode the key so it can be used as a comparable map key.
searchKeyString := fmt.Sprintf("%x", searchKey)
// ch is this caller's lock token; it is only visible to others once stored.
ch := make(chan struct{})
for {
// locked is true when another channel was already stored under this key.
existing, locked := ss.searchKeyLocks.LoadOrStore(searchKeyString, ch)
if locked {
// This search key is already locked. Wait for it to be unlocked and retry.
<-existing.(chan struct{})
continue
}
return func() {
// Remove the entry (only if it is still ours) before closing the
// channel, so that a woken waiter's LoadOrStore can succeed.
ss.searchKeyLocks.CompareAndDelete(searchKeyString, ch)
close(ch)
}
}
}
// failed records that at least one update for this shard did not complete
// successfully; wait reports this flag to its caller.
func (ss *shardState) failed() { ss.didFail.Store(true) }
// done marks one update registered through start as finished by releasing its
// slot in the wait group.
func (ss *shardState) done() { ss.waitGroup.Done() }
type Streamer struct {
config *config.APIConfig
tx db.TransparencyStore
@ -166,18 +64,11 @@ type Streamer struct {
// run runs the streamer, blocking forever.
func (s *Streamer) run(ctx context.Context, name string, startAtTimestamp *time.Time, updateHandler *KtUpdateHandler,
updateFunc updateFunc, checkpointSize uint) {
updateFunc updateFunc) {
i := 0
for {
// Create a new context and shard map for each run.
// Create a new context for each run.
runCtx, cancel := context.WithCancel(ctx)
shards := newShardMap()
// Note on thread safety: The Kinesis consumer library will use one
// goroutine per shard to scan. As such, a mutex is required to lookup shard
// state from the `shards` map because many shards may be read/written to
// the map in parallel. But the returned shardState struct can then be used
// without a mutex because there is only one goroutine working with it.
opts := []consumer.Option{
consumer.WithLogger(kinesisLogger{}),
@ -195,65 +86,61 @@ func (s *Streamer) run(ctx context.Context, name string, startAtTimestamp *time.
c, err := consumer.New(name, opts...)
if err != nil {
util.Log().Errorf("stream consumer initialization error: %v", err)
time.Sleep(5 * time.Second)
util.Log().Errorf("%s stream consumer initialization error: %v", name, err)
cancel()
continue
}
startAtTimestamp = nil
err = c.Scan(runCtx, func(r *consumer.Record) error {
// If start returns nil, the stream is shutting down and we should exit
state := shards.start(r.ShardID)
if state == nil {
return consumer.ErrSkipCheckpoint
}
go func(ctx context.Context, data []byte, state *shardState) {
defer state.done()
for {
select {
case <-ctx.Done():
state.failed()
return
default:
err := updateFunc(ctx, data, state, updateHandler, logUpdater)
if err != nil {
util.Log().Infof("failed to update entry from stream: %v", err)
metrics.IncrCounter([]string{withinStream, "errors"}, 1)
time.Sleep(3 * time.Second)
} else {
return
recordIteration := 0
// Loop until we successfully process the record, or the context is closed.
for {
select {
case <-runCtx.Done():
return consumer.ErrSkipCheckpoint
default:
err := updateFunc(runCtx, dup(r.Data), updateHandler, logUpdater)
if err != nil {
metrics.IncrCounterWithLabels([]string{withinStream, "errors"}, 1,
[]metrics.Label{{Name: "shardId", Value: r.ShardID}, {Name: "stream", Value: name}})
// Cap backoff at 30 seconds
sleep := time.Duration(math.Min(60, math.Pow(2, float64(recordIteration)))) * 500 * time.Millisecond
util.Log().Warnf(
"failed to update entry from stream: %v. streamName: %s, shardId: %s, seqNum: %s. iteration %d, sleeping %s.",
err, name, r.ShardID, *r.SequenceNumber, recordIteration, sleep)
select {
case <-runCtx.Done():
return consumer.ErrSkipCheckpoint
case <-time.After(sleep):
}
recordIteration++
continue
}
// Checkpoint after we finish processing each record
return nil
}
}(runCtx, dup(r.Data), state)
// If only a few entries have been sequenced from this shard, move on.
if uint(state.sinceLast) < checkpointSize {
return consumer.ErrSkipCheckpoint
}
// If many entries have been sequenced, we need to checkpoint. First
// wait for all processing updates to complete.
if failed := state.wait(); failed {
return consumer.ErrSkipCheckpoint
}
return nil
})
util.Log().Errorf("stream consumer error: %v", err)
// We only reach this point if c.Scan returns an error.
// In this case, clean up the current context, sleep with an exponential backoff,
// and wait for all spawned goroutines to exit.
if err != nil {
util.Log().Errorf("%s stream scan error: %v", name, err)
}
// Clean up the current context in case we iterate again.
cancel()
// Cap the backoff at 60 seconds
// If the context is closed and c.Scan exited with no error, don't iterate. This handles
// normal server shutdown behavior.
if ctx.Err() != nil && err == nil {
return
}
// Otherwise, sleep with exponential backoff (capped at 60 seconds).
delay := time.Duration(math.Min(60, math.Pow(2, float64(i)))) * time.Second
util.Log().Infof("iteration %d of stream consumer, sleeping %s", i, delay)
util.Log().Infof("iteration %d of %s stream consumer, sleeping %s", i, name, delay)
time.Sleep(delay)
// Ensure that all update goroutines have exited before restarting the consumer
shards.finish()
i++
}
}
@ -380,19 +267,17 @@ type mappingPair struct {
Type string
}
func update(ctx context.Context, state *shardState, updateHandler *KtUpdateHandler, updater Updater, pair *mappingPair) error {
func update(ctx context.Context, updateHandler *KtUpdateHandler, updater Updater, pair *mappingPair) error {
if pair.PrevKey == nil && pair.NextKey == nil {
// This should never happen, but we want to know about it if it does
metrics.IncrCounterWithLabels([]string{"stream_empty_pair"}, 1, []metrics.Label{{Name: "search_key_type", Value: pair.Type}})
return nil
} else if pair.NextKey == nil {
defer state.lockSearchKey(pair.PrevKey)()
if err := updater.update(ctx, withinStream,
pair.PrevKey, tombstoneBytes, updateHandler, marshalValue(pair.PrevVal)); err != nil {
return fmt.Errorf("updating %s: %w", pair.Type, err)
}
} else {
defer state.lockSearchKey(pair.NextKey)()
if !bytes.Equal(pair.PrevVal, pair.NextVal) {
if err := updater.update(ctx, withinStream,
pair.NextKey, marshalValue(pair.NextVal), updateHandler, nil); err != nil {
@ -418,7 +303,6 @@ type streamPair[T SearchKey] struct {
func updateFromStream[T SearchKey](
ctx context.Context,
data []byte,
state *shardState,
updateHandler *KtUpdateHandler,
updater Updater,
streamType string,
@ -440,7 +324,7 @@ func updateFromStream[T SearchKey](
nextKey, nextVal = extractKeyVal(pair.Next)
}
return update(ctx, state, updateHandler, updater, &mappingPair{
return update(ctx, updateHandler, updater, &mappingPair{
PrevKey: prevKey,
PrevVal: prevVal,
NextKey: nextKey,
@ -454,10 +338,10 @@ type e164 struct {
ACI []byte `json:"aci"`
}
type updateFunc func(context.Context, []byte, *shardState, *KtUpdateHandler, Updater) error
type updateFunc func(context.Context, []byte, *KtUpdateHandler, Updater) error
func updateFromE164Stream(ctx context.Context, data []byte, state *shardState, updateHandler *KtUpdateHandler, updater Updater) error {
return updateFromStream(ctx, data, state, updateHandler, updater, "number",
func updateFromE164Stream(ctx context.Context, data []byte, updateHandler *KtUpdateHandler, updater Updater) error {
return updateFromStream(ctx, data, updateHandler, updater, "number",
func(e *e164) (key []byte, value []byte) {
return append([]byte{shared.NumberPrefix}, []byte(e.Number)...), e.ACI
},
@ -469,8 +353,8 @@ type usernameHash struct {
ACI []byte `json:"aci"`
}
func updateFromUsernameStream(ctx context.Context, data []byte, state *shardState, updateHandler *KtUpdateHandler, updater Updater) error {
return updateFromStream(ctx, data, state, updateHandler, updater, "usernameHash",
func updateFromUsernameStream(ctx context.Context, data []byte, updateHandler *KtUpdateHandler, updater Updater) error {
return updateFromStream(ctx, data, updateHandler, updater, "usernameHash",
func(u *usernameHash) (key []byte, value []byte) {
return append([]byte{shared.UsernameHashPrefix}, u.UsernameHash...), u.ACI
},
@ -482,8 +366,8 @@ type aci struct {
ACIIdentityKey []byte `json:"aciIdentityKey"`
}
func updateFromAciStream(ctx context.Context, data []byte, state *shardState, updateHandler *KtUpdateHandler, updater Updater) error {
return updateFromStream(ctx, data, state, updateHandler, updater, "aci",
func updateFromAciStream(ctx context.Context, data []byte, updateHandler *KtUpdateHandler, updater Updater) error {
return updateFromStream(ctx, data, updateHandler, updater, "aci",
func(a *aci) (key []byte, value []byte) {
return append([]byte{shared.AciPrefix}, a.ACI...), a.ACIIdentityKey
},

View File

@ -9,7 +9,6 @@ import (
"context"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -106,30 +105,6 @@ func TestUpdateFromAciStream(t *testing.T) {
testStreamUpdate[aci](t, testUpdateAciPairs, updateFromAciStream)
}
// TestLockSearchKey checks that lockSearchKey serializes goroutines that
// contend for the same search key. Several goroutines each take the "label"
// lock, report the shared counter, sleep briefly, and then increment it; the
// reported values must arrive in increasing order starting at zero.
func TestLockSearchKey(t *testing.T) {
const parallel = 5
state := &shardState{}
// Hold a lock on an unrelated key for the whole test; it must not block
// the goroutines locking "label".
defer state.lockSearchKey([]byte("other"))()
counter := 0
// Unbuffered, so each send blocks until the main goroutine receives it.
output := make(chan int)
for range parallel {
go func() {
// The lock is taken when the defer statement is evaluated and
// released when this goroutine's function returns.
defer state.lockSearchKey([]byte("label"))()
output <- counter
time.Sleep(1 * time.Millisecond)
counter++
}()
}
// counter is only read and written while the "label" lock is held, so the
// i-th value received must equal i.
for i := range parallel {
if res := <-output; res != i {
t.Fatal("unexpected counter read")
}
}
}
type testCase[T SearchKey] struct {
pair *streamPair[T]
expectedNumUpdates int
@ -138,7 +113,7 @@ type testCase[T SearchKey] struct {
func testStreamUpdate[T SearchKey](t *testing.T,
pairs []testCase[T],
updaterFunc func(context.Context, []byte, *shardState, *KtUpdateHandler, Updater) error) {
updaterFunc func(context.Context, []byte, *KtUpdateHandler, Updater) error) {
mockConfig, _ := config.Read(mockConfigFile)
mockTransparencyStore := db.NewMemoryTransparencyStore()
updateRequestChannel := make(chan updateRequest)
@ -147,7 +122,6 @@ func testStreamUpdate[T SearchKey](t *testing.T,
tx: mockTransparencyStore,
ch: updateRequestChannel,
}
state := &shardState{}
for _, p := range pairs {
mockUpdater := new(mockLogUpdater)
@ -161,7 +135,7 @@ func testStreamUpdate[T SearchKey](t *testing.T,
mockUpdater.On("update", mock.Anything, mock.Anything, pair.key, pair.value, mock.Anything, pair.preUpdateValue).Return(nil)
}
err = updaterFunc(context.Background(), marshaledData, state, mockUpdateHandler, mockUpdater)
err = updaterFunc(context.Background(), marshaledData, mockUpdateHandler, mockUpdater)
assert.NoError(t, err)
mockUpdater.AssertNumberOfCalls(t, "update", p.expectedNumUpdates)