This commit is contained in:
parent
563f41039c
commit
5fb822474e
@ -128,7 +128,6 @@ type StreamConfig struct {
|
||||
AciStreamName envstr `yaml:"aci-stream-name"`
|
||||
E164StreamName envstr `yaml:"e164-stream-name"`
|
||||
UsernameStreamName envstr `yaml:"username-stream-name"`
|
||||
CheckpointSize uint `yaml:"checkpoint-size"`
|
||||
|
||||
NewStreams []string `yaml:"new-streams"`
|
||||
|
||||
@ -275,9 +274,6 @@ func Read(filename string) (*Config, error) {
|
||||
if parsed.StreamConfig.UsernameStreamName == "" {
|
||||
return nil, fmt.Errorf("field not provided: stream.username-stream-name")
|
||||
}
|
||||
if parsed.StreamConfig.CheckpointSize == 0 {
|
||||
return nil, fmt.Errorf("stream.checkpoint-size cannot be 0")
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.APIConfig.JitterPercent < 0 || parsed.APIConfig.JitterPercent > 100 {
|
||||
|
||||
@ -216,7 +216,7 @@ func main() {
|
||||
util.Log().Infof("%s stream start timestamp: %s", streamName, streamStartTimestamp.Format(time.RFC3339))
|
||||
}
|
||||
go func() {
|
||||
s.run(ctx, streamName, streamStartTimestamp, updateHandler, updateFromStreamFunc, config.StreamConfig.CheckpointSize)
|
||||
s.run(ctx, streamName, streamStartTimestamp, updateHandler, updateFromStreamFunc)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@ -13,8 +13,6 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
@ -59,106 +57,6 @@ type kinesisLogger struct{}
|
||||
|
||||
func (kl kinesisLogger) Log(v ...any) { util.Log().Infof("%s", fmt.Sprintln(v...)) }
|
||||
|
||||
// shardMap implements locking around a map from shard id to shard state, used
|
||||
// to coordinate many updater goroutines processing Kinesis records.
|
||||
type shardMap struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
// done is set to true when the Kinesis stream is being shutdown. This
|
||||
// prevents any further updates from being processed.
|
||||
done bool
|
||||
// shards is a map from shard id to shard state.
|
||||
shards map[string]*shardState
|
||||
}
|
||||
|
||||
func newShardMap() *shardMap {
|
||||
return &shardMap{
|
||||
mutex: sync.Mutex{},
|
||||
|
||||
done: false,
|
||||
shards: make(map[string]*shardState),
|
||||
}
|
||||
}
|
||||
|
||||
// start starts a new update for the given shard id. It returns the shard state.
|
||||
// If the stream is being shutdown, it returns nil.
|
||||
func (sm *shardMap) start(id string) *shardState {
|
||||
sm.mutex.Lock()
|
||||
defer sm.mutex.Unlock()
|
||||
|
||||
if sm.done {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, ok := sm.shards[id]; !ok {
|
||||
sm.shards[id] = &shardState{}
|
||||
}
|
||||
state := sm.shards[id]
|
||||
state.start()
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
// finish stops new updates from being processed and waits for all existing
|
||||
// updates to finish.
|
||||
func (sm *shardMap) finish() {
|
||||
sm.mutex.Lock()
|
||||
defer sm.mutex.Unlock()
|
||||
|
||||
sm.done = true
|
||||
for _, shard := range sm.shards {
|
||||
shard.waitGroup.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// shardState is the update state for an individual Kinesis shard.
|
||||
type shardState struct {
|
||||
// sinceLast is the number of records processed since the last checkpoint.
|
||||
sinceLast int
|
||||
// waitGroup tracks when all pending updates have been processed.
|
||||
waitGroup sync.WaitGroup
|
||||
// searchKeyLocks is a map from a search key to chan struct{}. It prevents multiple updates
|
||||
// from being processed for the same search key simultaneously.
|
||||
searchKeyLocks sync.Map
|
||||
// didFail is set to true if any updates failed to process.
|
||||
didFail atomic.Bool
|
||||
}
|
||||
|
||||
func (ss *shardState) start() {
|
||||
ss.sinceLast++
|
||||
ss.waitGroup.Add(1)
|
||||
}
|
||||
|
||||
// wait waits for all pending updates to finish. It returns true if any of the
|
||||
// updates failed to process.
|
||||
func (ss *shardState) wait() bool {
|
||||
ss.waitGroup.Wait()
|
||||
ss.sinceLast = 0
|
||||
return ss.didFail.Load()
|
||||
}
|
||||
|
||||
// lockSearchKey blocks until it is able to get an exclusive lock on the given search key. No
|
||||
// other goroutines are able to obtain a lock until `unlock` is called.
|
||||
func (ss *shardState) lockSearchKey(searchKey []byte) (unlock func()) {
|
||||
searchKeyString := fmt.Sprintf("%x", searchKey)
|
||||
ch := make(chan struct{})
|
||||
for {
|
||||
existing, locked := ss.searchKeyLocks.LoadOrStore(searchKeyString, ch)
|
||||
if locked {
|
||||
// This search key is already locked. Wait for it to be unlocked and retry.
|
||||
<-existing.(chan struct{})
|
||||
continue
|
||||
}
|
||||
return func() {
|
||||
ss.searchKeyLocks.CompareAndDelete(searchKeyString, ch)
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *shardState) failed() { ss.didFail.Store(true) }
|
||||
func (ss *shardState) done() { ss.waitGroup.Done() }
|
||||
|
||||
type Streamer struct {
|
||||
config *config.APIConfig
|
||||
tx db.TransparencyStore
|
||||
@ -166,18 +64,11 @@ type Streamer struct {
|
||||
|
||||
// run runs the streamer, blocking forever.
|
||||
func (s *Streamer) run(ctx context.Context, name string, startAtTimestamp *time.Time, updateHandler *KtUpdateHandler,
|
||||
updateFunc updateFunc, checkpointSize uint) {
|
||||
updateFunc updateFunc) {
|
||||
i := 0
|
||||
for {
|
||||
// Create a new context and shard map for each run.
|
||||
// Create a new context for each run.
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
shards := newShardMap()
|
||||
|
||||
// Note on thread safety: The Kinesis consumer library will use one
|
||||
// goroutine per shard to scan. As such, a mutex is required to lookup shard
|
||||
// state from the `shards` map because many shards may be read/written to
|
||||
// the map in parallel. But the returned shardState struct can then be used
|
||||
// without a mutex because there is only one goroutine working with it.
|
||||
|
||||
opts := []consumer.Option{
|
||||
consumer.WithLogger(kinesisLogger{}),
|
||||
@ -195,65 +86,61 @@ func (s *Streamer) run(ctx context.Context, name string, startAtTimestamp *time.
|
||||
|
||||
c, err := consumer.New(name, opts...)
|
||||
if err != nil {
|
||||
util.Log().Errorf("stream consumer initialization error: %v", err)
|
||||
time.Sleep(5 * time.Second)
|
||||
util.Log().Errorf("%s stream consumer initialization error: %v", name, err)
|
||||
cancel()
|
||||
continue
|
||||
}
|
||||
startAtTimestamp = nil
|
||||
|
||||
err = c.Scan(runCtx, func(r *consumer.Record) error {
|
||||
// If start returns nil, the stream is shutting down and we should exit
|
||||
state := shards.start(r.ShardID)
|
||||
if state == nil {
|
||||
return consumer.ErrSkipCheckpoint
|
||||
}
|
||||
|
||||
go func(ctx context.Context, data []byte, state *shardState) {
|
||||
defer state.done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
state.failed()
|
||||
return
|
||||
default:
|
||||
err := updateFunc(ctx, data, state, updateHandler, logUpdater)
|
||||
if err != nil {
|
||||
util.Log().Infof("failed to update entry from stream: %v", err)
|
||||
metrics.IncrCounter([]string{withinStream, "errors"}, 1)
|
||||
time.Sleep(3 * time.Second)
|
||||
} else {
|
||||
return
|
||||
recordIteration := 0
|
||||
// Loop until we successfully process the record, or the context is closed.
|
||||
for {
|
||||
select {
|
||||
case <-runCtx.Done():
|
||||
return consumer.ErrSkipCheckpoint
|
||||
default:
|
||||
err := updateFunc(runCtx, dup(r.Data), updateHandler, logUpdater)
|
||||
if err != nil {
|
||||
metrics.IncrCounterWithLabels([]string{withinStream, "errors"}, 1,
|
||||
[]metrics.Label{{Name: "shardId", Value: r.ShardID}, {Name: "stream", Value: name}})
|
||||
// Cap backoff at 30 seconds
|
||||
sleep := time.Duration(math.Min(60, math.Pow(2, float64(recordIteration)))) * 500 * time.Millisecond
|
||||
util.Log().Warnf(
|
||||
"failed to update entry from stream: %v. streamName: %s, shardId: %s, seqNum: %s. iteration %d, sleeping %s.",
|
||||
err, name, r.ShardID, *r.SequenceNumber, recordIteration, sleep)
|
||||
select {
|
||||
case <-runCtx.Done():
|
||||
return consumer.ErrSkipCheckpoint
|
||||
case <-time.After(sleep):
|
||||
}
|
||||
recordIteration++
|
||||
continue
|
||||
}
|
||||
// Checkpoint after we finish processing each record
|
||||
return nil
|
||||
}
|
||||
}(runCtx, dup(r.Data), state)
|
||||
|
||||
// If only a few entries have been sequenced from this shard, move on.
|
||||
if uint(state.sinceLast) < checkpointSize {
|
||||
return consumer.ErrSkipCheckpoint
|
||||
}
|
||||
// If many entries have been sequenced, we need to checkpoint. First
|
||||
// wait for all processing updates to complete.
|
||||
if failed := state.wait(); failed {
|
||||
return consumer.ErrSkipCheckpoint
|
||||
}
|
||||
return nil
|
||||
})
|
||||
util.Log().Errorf("stream consumer error: %v", err)
|
||||
|
||||
// We only reach this point if c.Scan returns an error.
|
||||
// In this case, clean up the current context, sleep with an exponential backoff,
|
||||
// and wait for all spawned goroutines to exit.
|
||||
if err != nil {
|
||||
util.Log().Errorf("%s stream scan error: %v", name, err)
|
||||
}
|
||||
|
||||
// Clean up the current context in case we iterate again.
|
||||
cancel()
|
||||
|
||||
// Cap the backoff at 60 seconds
|
||||
// If the context is closed and c.Scan exited with no error, don't iterate. This handles
|
||||
// normal server shutdown behavior.
|
||||
if ctx.Err() != nil && err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, sleep with exponential backoff (capped at 60 seconds).
|
||||
delay := time.Duration(math.Min(60, math.Pow(2, float64(i)))) * time.Second
|
||||
util.Log().Infof("iteration %d of stream consumer, sleeping %s", i, delay)
|
||||
util.Log().Infof("iteration %d of %s stream consumer, sleeping %s", i, name, delay)
|
||||
time.Sleep(delay)
|
||||
|
||||
// Ensure that all update goroutines have exited before restarting the consumer
|
||||
shards.finish()
|
||||
i++
|
||||
}
|
||||
}
|
||||
@ -380,19 +267,17 @@ type mappingPair struct {
|
||||
Type string
|
||||
}
|
||||
|
||||
func update(ctx context.Context, state *shardState, updateHandler *KtUpdateHandler, updater Updater, pair *mappingPair) error {
|
||||
func update(ctx context.Context, updateHandler *KtUpdateHandler, updater Updater, pair *mappingPair) error {
|
||||
if pair.PrevKey == nil && pair.NextKey == nil {
|
||||
// This should never happen, but we want to know about it if it does
|
||||
metrics.IncrCounterWithLabels([]string{"stream_empty_pair"}, 1, []metrics.Label{{Name: "search_key_type", Value: pair.Type}})
|
||||
return nil
|
||||
} else if pair.NextKey == nil {
|
||||
defer state.lockSearchKey(pair.PrevKey)()
|
||||
if err := updater.update(ctx, withinStream,
|
||||
pair.PrevKey, tombstoneBytes, updateHandler, marshalValue(pair.PrevVal)); err != nil {
|
||||
return fmt.Errorf("updating %s: %w", pair.Type, err)
|
||||
}
|
||||
} else {
|
||||
defer state.lockSearchKey(pair.NextKey)()
|
||||
if !bytes.Equal(pair.PrevVal, pair.NextVal) {
|
||||
if err := updater.update(ctx, withinStream,
|
||||
pair.NextKey, marshalValue(pair.NextVal), updateHandler, nil); err != nil {
|
||||
@ -418,7 +303,6 @@ type streamPair[T SearchKey] struct {
|
||||
func updateFromStream[T SearchKey](
|
||||
ctx context.Context,
|
||||
data []byte,
|
||||
state *shardState,
|
||||
updateHandler *KtUpdateHandler,
|
||||
updater Updater,
|
||||
streamType string,
|
||||
@ -440,7 +324,7 @@ func updateFromStream[T SearchKey](
|
||||
nextKey, nextVal = extractKeyVal(pair.Next)
|
||||
}
|
||||
|
||||
return update(ctx, state, updateHandler, updater, &mappingPair{
|
||||
return update(ctx, updateHandler, updater, &mappingPair{
|
||||
PrevKey: prevKey,
|
||||
PrevVal: prevVal,
|
||||
NextKey: nextKey,
|
||||
@ -454,10 +338,10 @@ type e164 struct {
|
||||
ACI []byte `json:"aci"`
|
||||
}
|
||||
|
||||
type updateFunc func(context.Context, []byte, *shardState, *KtUpdateHandler, Updater) error
|
||||
type updateFunc func(context.Context, []byte, *KtUpdateHandler, Updater) error
|
||||
|
||||
func updateFromE164Stream(ctx context.Context, data []byte, state *shardState, updateHandler *KtUpdateHandler, updater Updater) error {
|
||||
return updateFromStream(ctx, data, state, updateHandler, updater, "number",
|
||||
func updateFromE164Stream(ctx context.Context, data []byte, updateHandler *KtUpdateHandler, updater Updater) error {
|
||||
return updateFromStream(ctx, data, updateHandler, updater, "number",
|
||||
func(e *e164) (key []byte, value []byte) {
|
||||
return append([]byte{shared.NumberPrefix}, []byte(e.Number)...), e.ACI
|
||||
},
|
||||
@ -469,8 +353,8 @@ type usernameHash struct {
|
||||
ACI []byte `json:"aci"`
|
||||
}
|
||||
|
||||
func updateFromUsernameStream(ctx context.Context, data []byte, state *shardState, updateHandler *KtUpdateHandler, updater Updater) error {
|
||||
return updateFromStream(ctx, data, state, updateHandler, updater, "usernameHash",
|
||||
func updateFromUsernameStream(ctx context.Context, data []byte, updateHandler *KtUpdateHandler, updater Updater) error {
|
||||
return updateFromStream(ctx, data, updateHandler, updater, "usernameHash",
|
||||
func(u *usernameHash) (key []byte, value []byte) {
|
||||
return append([]byte{shared.UsernameHashPrefix}, u.UsernameHash...), u.ACI
|
||||
},
|
||||
@ -482,8 +366,8 @@ type aci struct {
|
||||
ACIIdentityKey []byte `json:"aciIdentityKey"`
|
||||
}
|
||||
|
||||
func updateFromAciStream(ctx context.Context, data []byte, state *shardState, updateHandler *KtUpdateHandler, updater Updater) error {
|
||||
return updateFromStream(ctx, data, state, updateHandler, updater, "aci",
|
||||
func updateFromAciStream(ctx context.Context, data []byte, updateHandler *KtUpdateHandler, updater Updater) error {
|
||||
return updateFromStream(ctx, data, updateHandler, updater, "aci",
|
||||
func(a *aci) (key []byte, value []byte) {
|
||||
return append([]byte{shared.AciPrefix}, a.ACI...), a.ACIIdentityKey
|
||||
},
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
@ -106,30 +105,6 @@ func TestUpdateFromAciStream(t *testing.T) {
|
||||
testStreamUpdate[aci](t, testUpdateAciPairs, updateFromAciStream)
|
||||
}
|
||||
|
||||
func TestLockSearchKey(t *testing.T) {
|
||||
const parallel = 5
|
||||
|
||||
state := &shardState{}
|
||||
defer state.lockSearchKey([]byte("other"))()
|
||||
|
||||
counter := 0
|
||||
output := make(chan int)
|
||||
for range parallel {
|
||||
go func() {
|
||||
defer state.lockSearchKey([]byte("label"))()
|
||||
output <- counter
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
counter++
|
||||
}()
|
||||
}
|
||||
|
||||
for i := range parallel {
|
||||
if res := <-output; res != i {
|
||||
t.Fatal("unexpected counter read")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type testCase[T SearchKey] struct {
|
||||
pair *streamPair[T]
|
||||
expectedNumUpdates int
|
||||
@ -138,7 +113,7 @@ type testCase[T SearchKey] struct {
|
||||
|
||||
func testStreamUpdate[T SearchKey](t *testing.T,
|
||||
pairs []testCase[T],
|
||||
updaterFunc func(context.Context, []byte, *shardState, *KtUpdateHandler, Updater) error) {
|
||||
updaterFunc func(context.Context, []byte, *KtUpdateHandler, Updater) error) {
|
||||
mockConfig, _ := config.Read(mockConfigFile)
|
||||
mockTransparencyStore := db.NewMemoryTransparencyStore()
|
||||
updateRequestChannel := make(chan updateRequest)
|
||||
@ -147,7 +122,6 @@ func testStreamUpdate[T SearchKey](t *testing.T,
|
||||
tx: mockTransparencyStore,
|
||||
ch: updateRequestChannel,
|
||||
}
|
||||
state := &shardState{}
|
||||
|
||||
for _, p := range pairs {
|
||||
mockUpdater := new(mockLogUpdater)
|
||||
@ -161,7 +135,7 @@ func testStreamUpdate[T SearchKey](t *testing.T,
|
||||
mockUpdater.On("update", mock.Anything, mock.Anything, pair.key, pair.value, mock.Anything, pair.preUpdateValue).Return(nil)
|
||||
}
|
||||
|
||||
err = updaterFunc(context.Background(), marshaledData, state, mockUpdateHandler, mockUpdater)
|
||||
err = updaterFunc(context.Background(), marshaledData, mockUpdateHandler, mockUpdater)
|
||||
|
||||
assert.NoError(t, err)
|
||||
mockUpdater.AssertNumberOfCalls(t, "update", p.expectedNumUpdates)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user