fix: handle discrawl termination safely
This commit is contained in:
parent
0b12c3c653
commit
25b1eb878d
@ -4,6 +4,10 @@ All notable changes to `discrawl` will be documented in this file.
|
||||
|
||||
## 0.6.4 - Unreleased
|
||||
|
||||
### Fixes
|
||||
|
||||
- `discrawl` now handles SIGINT/SIGTERM by canceling active sync/import contexts so large SQLite and FTS writes can roll back and close cleanly instead of being terminated mid-transaction.
|
||||
|
||||
## 0.6.3 - 2026-05-01
|
||||
|
||||
### Fixes
|
||||
|
||||
@ -4,12 +4,17 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/steipete/discrawl/internal/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := cli.Run(context.Background(), os.Args[1:], os.Stdout, os.Stderr); err != nil {
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
err := cli.Run(ctx, os.Args[1:], os.Stdout, os.Stderr)
|
||||
stop()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err.Error())
|
||||
os.Exit(cli.ExitCode(err))
|
||||
}
|
||||
|
||||
@ -2,9 +2,16 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/discrawl/internal/config"
|
||||
"github.com/steipete/discrawl/internal/store"
|
||||
)
|
||||
|
||||
func TestMainHelpAndVersion(t *testing.T) {
|
||||
@ -38,3 +45,132 @@ func TestMainHelpAndVersion(t *testing.T) {
|
||||
}
|
||||
t.Fatalf("expected exit code 2, got %v", err)
|
||||
}
|
||||
|
||||
func TestMainCancelsWatchOnSIGTERM(t *testing.T) {
|
||||
if os.Getenv("DISCRAWL_MAIN_SIGNAL_CHILD") == "1" {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = filepath.Join(dir, "discrawl.db")
|
||||
cfg.CacheDir = filepath.Join(dir, "cache")
|
||||
cfg.LogDir = filepath.Join(dir, "logs")
|
||||
cfg.Desktop.Path = filepath.Join(dir, "discord")
|
||||
requireNoError(t, os.MkdirAll(cfg.Desktop.Path, 0o755))
|
||||
requireNoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
oldArgs := os.Args
|
||||
t.Cleanup(func() { os.Args = oldArgs })
|
||||
os.Args = []string{"discrawl", "--config", cfgPath, "wiretap", "--dry-run", "--watch-every", "1s"}
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err == nil {
|
||||
_ = process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
}()
|
||||
main()
|
||||
return
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
t.Fatalf("os.Executable: %v", err)
|
||||
}
|
||||
cmd := exec.CommandContext(t.Context(), exe, "-test.run=TestMainCancelsWatchOnSIGTERM")
|
||||
cmd.Env = append(os.Environ(), "DISCRAWL_MAIN_SIGNAL_CHILD=1")
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatalf("expected graceful SIGTERM cancellation, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMainCancelsWiretapImportOnSIGTERMWithoutCorruptingDB(t *testing.T) {
|
||||
if dir := os.Getenv("DISCRAWL_MAIN_IMPORT_SIGNAL_DIR"); dir != "" {
|
||||
runWiretapImportSignalChild(t, dir)
|
||||
return
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
t.Fatalf("os.Executable: %v", err)
|
||||
}
|
||||
cmd := exec.CommandContext(t.Context(), exe, "-test.run=TestMainCancelsWiretapImportOnSIGTERMWithoutCorruptingDB")
|
||||
cmd.Env = append(os.Environ(), "DISCRAWL_MAIN_IMPORT_SIGNAL_DIR="+dir)
|
||||
output, err := cmd.CombinedOutput()
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("expected context-canceled exit from SIGTERM, got err=%v output=%s", err, output)
|
||||
}
|
||||
if exitErr.ExitCode() != 1 {
|
||||
t.Fatalf("expected graceful exit code 1, got %d output=%s", exitErr.ExitCode(), output)
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
s, err := store.Open(ctx, filepath.Join(dir, "discrawl.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db after SIGTERM: %v output=%s", err, output)
|
||||
}
|
||||
defer func() { _ = s.Close() }()
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "pragma quick_check")
|
||||
if err != nil {
|
||||
t.Fatalf("quick_check after SIGTERM: %v output=%s", err, output)
|
||||
}
|
||||
if len(rows) != 1 || len(rows[0]) != 1 || rows[0][0] != "ok" {
|
||||
t.Fatalf("quick_check after SIGTERM = %#v output=%s", rows, output)
|
||||
}
|
||||
}
|
||||
|
||||
func runWiretapImportSignalChild(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = filepath.Join(dir, "discrawl.db")
|
||||
cfg.CacheDir = filepath.Join(dir, "cache")
|
||||
cfg.LogDir = filepath.Join(dir, "logs")
|
||||
cfg.Desktop.Path = filepath.Join(dir, "discord")
|
||||
cfg.Discord.TokenSource = "none"
|
||||
cfg.Share.AutoUpdate = false
|
||||
cachePath := filepath.Join(cfg.Desktop.Path, "Local Storage", "leveldb")
|
||||
requireNoError(t, os.MkdirAll(cachePath, 0o755))
|
||||
requireNoError(t, config.Write(cfgPath, cfg))
|
||||
writeLargeWiretapCache(t, filepath.Join(cachePath, "000001.log"), 50000)
|
||||
|
||||
oldArgs := os.Args
|
||||
t.Cleanup(func() { os.Args = oldArgs })
|
||||
os.Args = []string{"discrawl", "--config", cfgPath, "wiretap", "--path", cfg.Desktop.Path}
|
||||
go func() {
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err == nil {
|
||||
_ = process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
}()
|
||||
main()
|
||||
}
|
||||
|
||||
func writeLargeWiretapCache(t *testing.T, path string, count int) {
|
||||
t.Helper()
|
||||
|
||||
file, err := os.Create(path)
|
||||
requireNoError(t, err)
|
||||
defer func() { requireNoError(t, file.Close()) }()
|
||||
_, err = fmt.Fprintln(file, `{"id":"111111111111111117","guild_id":"999999999999999997","type":0,"name":"sigterm-import"}`)
|
||||
requireNoError(t, err)
|
||||
for i := range count {
|
||||
_, err = fmt.Fprintf(
|
||||
file,
|
||||
`{"id":"3333333333%09d","channel_id":"111111111111111117","content":"sigterm import message %d","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222228","username":"alice"}}`+"\n",
|
||||
i,
|
||||
i,
|
||||
)
|
||||
requireNoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// requireNoError fails the test immediately via t.Fatal when err is non-nil and
// does nothing otherwise. It is marked as a helper so failures are attributed
// to the caller's line.
func requireNoError(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatal(err)
}
|
||||
|
||||
@ -246,11 +246,17 @@ func scan(ctx context.Context, opts Options, state scanState) (Stats, snapshot,
|
||||
collectChannelRoutes(snap, bytes.ToValidUTF8(data, nil))
|
||||
objects := extractJSONValues(bytes.ToValidUTF8(data, nil))
|
||||
for _, payload := range extractGzipPayloads(data, maxBytes) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
collectChannelRoutes(snap, bytes.ToValidUTF8(payload, nil))
|
||||
objects = append(objects, extractJSONValues(bytes.ToValidUTF8(payload, nil))...)
|
||||
}
|
||||
stats.JSONObjects += len(objects)
|
||||
for _, raw := range objects {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
continue
|
||||
@ -320,6 +326,9 @@ func writeSnapshot(ctx context.Context, st *store.Store, snap snapshot, prune bo
|
||||
guilds := mapValues(snap.guilds)
|
||||
sort.Slice(guilds, func(i, j int) bool { return guilds[i].ID < guilds[j].ID })
|
||||
for _, guild := range guilds {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := st.UpsertGuild(ctx, guild); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -327,6 +336,9 @@ func writeSnapshot(ctx context.Context, st *store.Store, snap snapshot, prune bo
|
||||
channels := mapValues(snap.channels)
|
||||
sort.Slice(channels, func(i, j int) bool { return channels[i].ID < channels[j].ID })
|
||||
for _, channel := range channels {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := st.UpsertChannel(ctx, channel); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -253,6 +253,9 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
}
|
||||
}
|
||||
for _, table := range manifest.Tables {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
if err := importTable(ctx, tx, opts.RepoPath, table); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
@ -439,6 +442,9 @@ func exportTable(ctx context.Context, db *sql.DB, repoPath, table string) (Table
|
||||
ptrs[i] = &values[i]
|
||||
}
|
||||
for rows.Next() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return TableManifest{}, err
|
||||
}
|
||||
if err := rows.Scan(ptrs...); err != nil {
|
||||
return TableManifest{}, fmt.Errorf("scan %s: %w", table, err)
|
||||
}
|
||||
@ -509,6 +515,9 @@ func exportEmbeddings(ctx context.Context, db *sql.DB, opts Options) (EmbeddingM
|
||||
columns := []string{"message_id", "provider", "model", "input_version", "dimensions", "embedding_blob", "embedded_at"}
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return EmbeddingManifest{}, err
|
||||
}
|
||||
var (
|
||||
messageID string
|
||||
rowProv string
|
||||
@ -578,6 +587,9 @@ func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableMa
|
||||
}
|
||||
defer func() { _ = stmt.Close() }()
|
||||
for _, rel := range files {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := importTableFile(ctx, stmt, repoPath, table, columns, rel); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -600,6 +612,9 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
|
||||
dec := json.NewDecoder(gz)
|
||||
dec.UseNumber()
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
row := map[string]any{}
|
||||
err := dec.Decode(&row)
|
||||
if err == io.EOF {
|
||||
@ -760,6 +775,9 @@ func importEmbeddings(ctx context.Context, tx *sql.Tx, opts Options, manifests [
|
||||
}
|
||||
defer func() { _ = stmt.Close() }()
|
||||
for _, manifest := range manifests {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if !embeddingManifestMatches(opts, manifest) {
|
||||
continue
|
||||
}
|
||||
@ -768,6 +786,9 @@ func importEmbeddings(ctx context.Context, tx *sql.Tx, opts Options, manifests [
|
||||
return fmt.Errorf("embedding manifest %s/%s/%s has no files", manifest.Provider, manifest.Model, manifest.InputVersion)
|
||||
}
|
||||
for _, rel := range files {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := importEmbeddingFile(ctx, stmt, opts.RepoPath, rel); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -791,6 +812,9 @@ func importEmbeddingFile(ctx context.Context, stmt *sql.Stmt, repoPath, rel stri
|
||||
dec := json.NewDecoder(gz)
|
||||
dec.UseNumber()
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
var row struct {
|
||||
MessageID string `json:"message_id"`
|
||||
Provider string `json:"provider"`
|
||||
|
||||
@ -98,6 +98,9 @@ func (s *Store) rebuildMemberFTS(ctx context.Context) error {
|
||||
defer func() { _ = stmt.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
var guildID string
|
||||
var userID string
|
||||
var username string
|
||||
|
||||
@ -623,6 +623,9 @@ func (s *Store) rebuildFTS(ctx context.Context) error {
|
||||
defer func() { _ = stmt.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
messageID string
|
||||
guildID string
|
||||
|
||||
@ -70,6 +70,35 @@ func TestUpsertMessagesBatch(t *testing.T) {
|
||||
require.Equal(t, "2", rows[0][0])
|
||||
}
|
||||
|
||||
func TestUpsertMessagesHonorsCanceledContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
canceled, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
err = s.UpsertMessages(canceled, []MessageMutation{{
|
||||
Record: MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "one",
|
||||
NormalizedContent: "one",
|
||||
RawJSON: `{"id":"m1"}`,
|
||||
},
|
||||
}})
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from messages")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "0", rows[0][0])
|
||||
}
|
||||
|
||||
func TestUpsertMessagesSkipsEventsAndEmbeddingsByDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@ -298,6 +298,9 @@ func (s *Store) UpsertMessages(ctx context.Context, messages []MessageMutation)
|
||||
}
|
||||
defer rollback(tx)
|
||||
for _, message := range messages {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := upsertMessageTx(ctx, tx, message.Record, message.Options); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user