diff --git a/snapshot/snapshot.go b/snapshot/snapshot.go
index 39a6cbb..ac55cd6 100644
--- a/snapshot/snapshot.go
+++ b/snapshot/snapshot.go
@@ -38,6 +38,9 @@ type ImportOptions struct {
 	DeleteTables []string
 	DeleteTable  DeleteFunc
 	Filter       RowFilter
+	// Progress, if non-nil, is called synchronously as each table and each
+	// of its files starts and finishes importing.
+	Progress     func(ImportProgress)
 	BeforeImport func(context.Context, *sql.Tx) error
 	AfterImport  func(context.Context, *sql.Tx) error
 }
@@ -46,6 +49,24 @@ type RowFilter func(table string, row map[string]any) (bool, error)
 
 type DeleteFunc func(ctx context.Context, tx *sql.Tx, table string) error
 
+// ImportProgress is one progress event passed to ImportOptions.Progress.
+//
+// Phase is "table_start", "file_start", "file_done" or "table_done". Table and
+// TotalRows (the row count the manifest records for the table) are set on
+// every event. File and FileIndex (1-based) are set on file events only, and
+// FileCount on every event except table_done. Rows is the number of rows
+// inserted by the file (file_done) or by the whole table (table_done); it is
+// zero on start events.
+type ImportProgress struct {
+	Phase     string
+	Table     string
+	File      string
+	FileIndex int
+	FileCount int
+	Rows      int
+	TotalRows int
+}
+
 type Sidecar struct {
 	Name string `json:"name"`
 	Path string `json:"path"`
@@ -159,9 +180,11 @@ func Import(ctx context.Context, opts ImportOptions) (Manifest, error) {
 		}
 	}
 	for _, table := range manifest.Tables {
-		if err := importTable(ctx, tx, opts.RootDir, table, opts.Filter); err != nil {
+		rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.Progress)
+		if err != nil {
 			return Manifest{}, err
 		}
+		reportImportProgress(opts.Progress, ImportProgress{Phase: "table_done", Table: table.Name, Rows: rows, TotalRows: table.Rows})
 	}
 	if opts.AfterImport != nil {
 		if err := opts.AfterImport(ctx, tx); err != nil {
@@ -268,43 +291,62 @@ func exportTable(ctx context.Context, db *sql.DB, rootDir, table string, maxShar
 	return TableManifest{Name: table, Files: writer.files, Columns: cols, Rows: count}, nil
 }
 
-func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter) error {
+// importTable inserts the rows of every file listed for table, falling back to
+// the singular File entry when Files is empty. It reports table_start,
+// file_start and file_done events to progress and returns the number of rows
+// inserted; on error the count covers only the files completed so far.
+func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter, progress func(ImportProgress)) (int, error) {
 	files := table.Files
 	if len(files) == 0 && strings.TrimSpace(table.File) != "" {
 		files = []string{table.File}
 	}
+	// Report table_start before the empty check: Import emits table_done for
+	// every table, so a table without files must still get a matching start.
+	reportImportProgress(progress, ImportProgress{Phase: "table_start", Table: table.Name, FileCount: len(files), TotalRows: table.Rows})
 	if len(files) == 0 {
-		return nil
+		return 0, nil
 	}
-	for _, rel := range files {
+	totalRows := 0
+	for index, rel := range files {
 		path := filepath.Join(rootDir, filepath.FromSlash(rel))
 		file, err := os.Open(path)
 		if err != nil {
-			return fmt.Errorf("open %s: %w", rel, err)
+			return totalRows, fmt.Errorf("open %s: %w", rel, err)
 		}
-		if err := importJSONLGzip(ctx, tx, file, table.Name, filter); err != nil {
+		fileProgress := ImportProgress{Phase: "file_start", Table: table.Name, File: rel, FileIndex: index + 1, FileCount: len(files), TotalRows: table.Rows}
+		reportImportProgress(progress, fileProgress)
+		rows, err := importJSONLGzip(ctx, tx, file, table.Name, filter)
+		if err != nil {
 			_ = file.Close()
-			return err
+			return totalRows, err
 		}
 		if err := file.Close(); err != nil {
-			return fmt.Errorf("close %s: %w", rel, err)
+			return totalRows, fmt.Errorf("close %s: %w", rel, err)
 		}
+		totalRows += rows
+		fileProgress.Phase = "file_done"
+		fileProgress.Rows = rows
+		reportImportProgress(progress, fileProgress)
 	}
-	return nil
+	return totalRows, nil
 }
 
-func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter) error {
+// importJSONLGzip reads gzip-compressed JSON Lines from reader and inserts each
+// decoded row into table. Empty objects and rows rejected by filter are
+// skipped. It returns the number of rows inserted.
+func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter) (int, error) {
 	gz, err := gzip.NewReader(reader)
 	if err != nil {
-		return fmt.Errorf("open gzip for %s: %w", table, err)
+		return 0, fmt.Errorf("open gzip for %s: %w", table, err)
 	}
 	defer gz.Close()
 	scanner := bufio.NewScanner(gz)
 	scanner.Buffer(make([]byte, 0, 1024*1024), 64*1024*1024)
+	rows := 0
 	for scanner.Scan() {
 		var row map[string]any
 		if err := json.Unmarshal(scanner.Bytes(), &row); err != nil {
-			return fmt.Errorf("decode %s row: %w", table, err)
+			return rows, fmt.Errorf("decode %s row: %w", table, err)
 		}
 		if len(row) == 0 {
 			continue
@@ -312,20 +354,28 @@ func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table st
 		if filter != nil {
 			keep, err := filter(table, row)
 			if err != nil {
-				return fmt.Errorf("filter %s row: %w", table, err)
+				return rows, fmt.Errorf("filter %s row: %w", table, err)
 			}
 			if !keep {
 				continue
 			}
 		}
 		if err := insertRow(ctx, tx, table, row); err != nil {
-			return err
+			return rows, err
 		}
+		rows++
 	}
 	if err := scanner.Err(); err != nil {
-		return fmt.Errorf("scan %s rows: %w", table, err)
+		return rows, fmt.Errorf("scan %s rows: %w", table, err)
+	}
+	return rows, nil
+}
+
+// reportImportProgress passes event to progress if a callback was supplied.
+func reportImportProgress(progress func(ImportProgress), event ImportProgress) {
+	if progress != nil {
+		progress(event)
 	}
-	return nil
 }
 
 func insertRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error {
diff --git a/snapshot/snapshot_test.go b/snapshot/snapshot_test.go
index 611ea89..cea1829 100644
--- a/snapshot/snapshot_test.go
+++ b/snapshot/snapshot_test.go
@@ -145,6 +145,56 @@ func TestImportHooks(t *testing.T) {
 	}
 }
 
+// TestImportReportsTableAndFileProgress exports a two-row table, imports it
+// with a filter that drops one row, and checks the reported progress events.
+func TestImportReportsTableAndFileProgress(t *testing.T) {
+	ctx := context.Background()
+	src, err := store.Open(ctx, store.Options{
+		Path:   filepath.Join(t.TempDir(), "src.db"),
+		Schema: `create table things(id text primary key, body text not null);`,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer src.Close()
+	mustExec(t, src.DB(), `insert into things(id, body) values('one', 'keep')`)
+	mustExec(t, src.DB(), `insert into things(id, body) values('two', 'skip')`)
+	root := t.TempDir()
+	if _, err := Export(ctx, ExportOptions{DB: src.DB(), RootDir: root, Tables: []string{"things"}}); err != nil {
+		t.Fatal(err)
+	}
+
+	dst, err := store.Open(ctx, store.Options{
+		Path:   filepath.Join(t.TempDir(), "dst.db"),
+		Schema: `create table things(id text primary key, body text not null);`,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer dst.Close()
+	var progress []ImportProgress
+	if _, err := Import(ctx, ImportOptions{
+		DB:      dst.DB(),
+		RootDir: root,
+		Filter: func(table string, row map[string]any) (bool, error) {
+			return row["id"] != "two", nil
+		},
+		Progress: func(event ImportProgress) {
+			progress = append(progress, event)
+		},
+	}); err != nil {
+		t.Fatal(err)
+	}
+	for _, phase := range []string{"table_start", "file_start", "file_done", "table_done"} {
+		if !containsImportPhase(progress, phase) {
+			t.Fatalf("progress missing %q: %+v", phase, progress)
+		}
+	}
+	if got := progress[len(progress)-1]; got.Phase != "table_done" || got.Table != "things" || got.Rows != 1 || got.TotalRows != 2 {
+		t.Fatalf("table_done progress = %+v", got)
+	}
+}
+
 func TestImportLegacySingularFileManifest(t *testing.T) {
 	ctx := context.Background()
 	root := t.TempDir()
@@ -232,6 +282,16 @@ func TestImportFilterSkipsRows(t *testing.T) {
 	}
 }
 
+// containsImportPhase reports whether any event in progress has the given phase.
+func containsImportPhase(progress []ImportProgress, phase string) bool {
+	for _, event := range progress {
+		if event.Phase == phase {
+			return true
+		}
+	}
+	return false
+}
+
 func mustExec(t *testing.T, db *sql.DB, query string) {
 	t.Helper()
 	if _, err := db.Exec(query); err != nil {