feat(snapshot): report import table progress

This commit is contained in:
Vincent Koc 2026-05-03 23:49:58 -07:00
parent 4080014321
commit ee1c3057d0
No known key found for this signature in database
2 changed files with 103 additions and 16 deletions

View File

@ -38,6 +38,7 @@ type ImportOptions struct {
DeleteTables []string
DeleteTable DeleteFunc
Filter RowFilter
Progress func(ImportProgress)
BeforeImport func(context.Context, *sql.Tx) error
AfterImport func(context.Context, *sql.Tx) error
}
@ -46,6 +47,16 @@ type RowFilter func(table string, row map[string]any) (bool, error)
// DeleteFunc is the signature of the ImportOptions.DeleteTable hook: it
// receives the import transaction and a table name and reports an error.
// NOTE(review): presumably called for the tables named in
// ImportOptions.DeleteTables before rows are imported — confirm against Import.
type DeleteFunc func(ctx context.Context, tx *sql.Tx, table string) error
// ImportProgress is the event handed to ImportOptions.Progress while tables
// are being imported. Which fields are populated depends on Phase.
type ImportProgress struct {
	// Phase names the event: "table_start", "file_start", "file_done" or
	// "table_done". Table is the table being imported. File is the
	// manifest-relative data file and is only set on file_* events.
	Phase, Table, File string

	// FileIndex is the 1-based position of File within the table's file
	// list (file_* events only). FileCount is the number of data files for
	// the table (table_start and file_* events).
	FileIndex, FileCount int

	// Rows is the number of rows actually inserted: for file_done the rows
	// from that file, for table_done the sum over all files. TotalRows is
	// the row count recorded for the table in the manifest.
	Rows, TotalRows int
}
type Sidecar struct {
Name string `json:"name"`
Path string `json:"path"`
@ -159,9 +170,11 @@ func Import(ctx context.Context, opts ImportOptions) (Manifest, error) {
}
}
for _, table := range manifest.Tables {
if err := importTable(ctx, tx, opts.RootDir, table, opts.Filter); err != nil {
rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.Progress)
if err != nil {
return Manifest{}, err
}
reportImportProgress(opts.Progress, ImportProgress{Phase: "table_done", Table: table.Name, Rows: rows, TotalRows: table.Rows})
}
if opts.AfterImport != nil {
if err := opts.AfterImport(ctx, tx); err != nil {
@ -268,43 +281,53 @@ func exportTable(ctx context.Context, db *sql.DB, rootDir, table string, maxShar
return TableManifest{Name: table, Files: writer.files, Columns: cols, Rows: count}, nil
}
// importTable imports every data file listed for table inside tx and returns
// the number of rows inserted across those files. It reports table_start once,
// then file_start / file_done around each file, through progress (a nil
// progress is tolerated by reportImportProgress).
//
// NOTE(review): this span is a rendered diff without +/- markers. Each line
// that returns a bare error (and the first signature / first `for` header) is
// the pre-change version of the line that directly follows it.
func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter) error {
func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter, progress func(ImportProgress)) (int, error) {
files := table.Files
// Fall back to the singular File field when Files is empty and File is non-blank.
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
files = []string{table.File}
}
// Nothing to import: no progress events are emitted for this table here.
if len(files) == 0 {
return nil
return 0, nil
}
for _, rel := range files {
reportImportProgress(progress, ImportProgress{Phase: "table_start", Table: table.Name, FileCount: len(files), TotalRows: table.Rows})
totalRows := 0
for index, rel := range files {
// Manifest paths are slash-separated; convert before joining with rootDir.
path := filepath.Join(rootDir, filepath.FromSlash(rel))
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("open %s: %w", rel, err)
return totalRows, fmt.Errorf("open %s: %w", rel, err)
}
if err := importJSONLGzip(ctx, tx, file, table.Name, filter); err != nil {
// file_start is reported only after the file opened successfully; FileIndex is 1-based.
fileProgress := ImportProgress{Phase: "file_start", Table: table.Name, File: rel, FileIndex: index + 1, FileCount: len(files), TotalRows: table.Rows}
reportImportProgress(progress, fileProgress)
rows, err := importJSONLGzip(ctx, tx, file, table.Name, filter)
if err != nil {
// Best-effort close: the import error is the one returned. The rows
// counted for the failing file are not added to totalRows.
_ = file.Close()
return err
return totalRows, err
}
if err := file.Close(); err != nil {
return fmt.Errorf("close %s: %w", rel, err)
return totalRows, fmt.Errorf("close %s: %w", rel, err)
}
totalRows += rows
// Reuse the file_start event so file_done carries the same file metadata plus Rows.
fileProgress.Phase = "file_done"
fileProgress.Rows = rows
reportImportProgress(progress, fileProgress)
}
return nil
return totalRows, nil
}
// importJSONLGzip reads gzip-compressed, newline-delimited JSON objects from
// reader and inserts each kept row into table via insertRow. It returns the
// number of rows inserted; on error it returns the count reached so far
// together with the error.
//
// NOTE(review): this span is a rendered diff without +/- markers. Each line
// that returns a bare error (and the first signature) is the pre-change
// version of the line that directly follows it.
func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter) error {
func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter) (int, error) {
gz, err := gzip.NewReader(reader)
if err != nil {
return fmt.Errorf("open gzip for %s: %w", table, err)
return 0, fmt.Errorf("open gzip for %s: %w", table, err)
}
defer gz.Close()
scanner := bufio.NewScanner(gz)
// Start with a 1 MiB buffer and allow a single line to grow up to 64 MiB.
scanner.Buffer(make([]byte, 0, 1024*1024), 64*1024*1024)
rows := 0
for scanner.Scan() {
var row map[string]any
if err := json.Unmarshal(scanner.Bytes(), &row); err != nil {
return fmt.Errorf("decode %s row: %w", table, err)
return rows, fmt.Errorf("decode %s row: %w", table, err)
}
// Rows that decode to an empty (or null) object are skipped and not counted.
if len(row) == 0 {
continue
// NOTE(review): the hunk boundary below elides one line of the original file
// here, presumably the closing brace of the empty-row check — confirm.
@ -312,20 +335,27 @@ func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table st
if filter != nil {
keep, err := filter(table, row)
if err != nil {
return fmt.Errorf("filter %s row: %w", table, err)
return rows, fmt.Errorf("filter %s row: %w", table, err)
}
// Filtered-out rows are skipped and not counted.
if !keep {
continue
}
}
if err := insertRow(ctx, tx, table, row); err != nil {
return err
return rows, err
}
// Count only rows that were actually inserted.
rows++
}
// Surface read errors from the scanner that ended the loop early.
if err := scanner.Err(); err != nil {
return fmt.Errorf("scan %s rows: %w", table, err)
return rows, fmt.Errorf("scan %s rows: %w", table, err)
}
return rows, nil
}
// reportImportProgress passes event to progress when a callback was supplied;
// a nil progress makes the call a no-op.
func reportImportProgress(progress func(ImportProgress), event ImportProgress) {
if progress != nil {
progress(event)
}
// NOTE(review): this helper declares no results, so the `return nil` below
// appears to be a removed line from the previous end of importJSONLGzip that
// the diff rendering placed here — confirm against the raw diff.
return nil
}
func insertRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error {

View File

@ -145,6 +145,54 @@ func TestImportHooks(t *testing.T) {
}
}
// TestImportReportsTableAndFileProgress exports a two-row table, imports it
// through a filter that drops one row, and verifies that the Progress callback
// saw every phase and ended with a table_done event carrying the imported and
// manifest row counts.
func TestImportReportsTableAndFileProgress(t *testing.T) {
	const thingsSchema = `create table things(id text primary key, body text not null);`
	ctx := context.Background()

	// Source database with one row to keep and one to filter out.
	srcOpts := store.Options{
		Path:   filepath.Join(t.TempDir(), "src.db"),
		Schema: thingsSchema,
	}
	src, err := store.Open(ctx, srcOpts)
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()
	for _, stmt := range []string{
		`insert into things(id, body) values('one', 'keep')`,
		`insert into things(id, body) values('two', 'skip')`,
	} {
		mustExec(t, src.DB(), stmt)
	}

	root := t.TempDir()
	_, err = Export(ctx, ExportOptions{DB: src.DB(), RootDir: root, Tables: []string{"things"}})
	if err != nil {
		t.Fatal(err)
	}

	// Empty destination database with the same schema.
	dstOpts := store.Options{
		Path:   filepath.Join(t.TempDir(), "dst.db"),
		Schema: thingsSchema,
	}
	dst, err := store.Open(ctx, dstOpts)
	if err != nil {
		t.Fatal(err)
	}
	defer dst.Close()

	var events []ImportProgress
	dropTwo := func(table string, row map[string]any) (bool, error) {
		return row["id"] != "two", nil
	}
	record := func(event ImportProgress) {
		events = append(events, event)
	}
	_, err = Import(ctx, ImportOptions{
		DB:       dst.DB(),
		RootDir:  root,
		Filter:   dropTwo,
		Progress: record,
	})
	if err != nil {
		t.Fatal(err)
	}

	wantPhases := []string{"table_start", "file_start", "file_done", "table_done"}
	for _, phase := range wantPhases {
		if containsImportPhase(events, phase) {
			continue
		}
		t.Fatalf("progress missing %q: %+v", phase, events)
	}

	// The final event must be table_done: one row imported out of two exported.
	last := events[len(events)-1]
	matches := last.Phase == "table_done" &&
		last.Table == "things" &&
		last.Rows == 1 &&
		last.TotalRows == 2
	if !matches {
		t.Fatalf("table_done progress = %+v", last)
	}
}
func TestImportLegacySingularFileManifest(t *testing.T) {
ctx := context.Background()
root := t.TempDir()
@ -232,6 +280,15 @@ func TestImportFilterSkipsRows(t *testing.T) {
}
}
// containsImportPhase reports whether any event in progress has the given
// Phase. An empty or nil slice yields false.
func containsImportPhase(progress []ImportProgress, phase string) bool {
	found := false
	for i := 0; i < len(progress) && !found; i++ {
		found = progress[i].Phase == phase
	}
	return found
}
func mustExec(t *testing.T, db *sql.DB, query string) {
t.Helper()
if _, err := db.Exec(query); err != nil {