feat(snapshot): report import table progress
This commit is contained in:
parent
4080014321
commit
ee1c3057d0
@ -38,6 +38,7 @@ type ImportOptions struct {
|
||||
DeleteTables []string
|
||||
DeleteTable DeleteFunc
|
||||
Filter RowFilter
|
||||
Progress func(ImportProgress)
|
||||
BeforeImport func(context.Context, *sql.Tx) error
|
||||
AfterImport func(context.Context, *sql.Tx) error
|
||||
}
|
||||
@ -46,6 +47,16 @@ type RowFilter func(table string, row map[string]any) (bool, error)
|
||||
|
||||
type DeleteFunc func(ctx context.Context, tx *sql.Tx, table string) error
|
||||
|
||||
type ImportProgress struct {
|
||||
Phase string
|
||||
Table string
|
||||
File string
|
||||
FileIndex int
|
||||
FileCount int
|
||||
Rows int
|
||||
TotalRows int
|
||||
}
|
||||
|
||||
type Sidecar struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
@ -159,9 +170,11 @@ func Import(ctx context.Context, opts ImportOptions) (Manifest, error) {
|
||||
}
|
||||
}
|
||||
for _, table := range manifest.Tables {
|
||||
if err := importTable(ctx, tx, opts.RootDir, table, opts.Filter); err != nil {
|
||||
rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.Progress)
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
reportImportProgress(opts.Progress, ImportProgress{Phase: "table_done", Table: table.Name, Rows: rows, TotalRows: table.Rows})
|
||||
}
|
||||
if opts.AfterImport != nil {
|
||||
if err := opts.AfterImport(ctx, tx); err != nil {
|
||||
@ -268,43 +281,53 @@ func exportTable(ctx context.Context, db *sql.DB, rootDir, table string, maxShar
|
||||
return TableManifest{Name: table, Files: writer.files, Columns: cols, Rows: count}, nil
|
||||
}
|
||||
|
||||
func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter) error {
|
||||
func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter, progress func(ImportProgress)) (int, error) {
|
||||
files := table.Files
|
||||
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
|
||||
files = []string{table.File}
|
||||
}
|
||||
if len(files) == 0 {
|
||||
return nil
|
||||
return 0, nil
|
||||
}
|
||||
for _, rel := range files {
|
||||
reportImportProgress(progress, ImportProgress{Phase: "table_start", Table: table.Name, FileCount: len(files), TotalRows: table.Rows})
|
||||
totalRows := 0
|
||||
for index, rel := range files {
|
||||
path := filepath.Join(rootDir, filepath.FromSlash(rel))
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open %s: %w", rel, err)
|
||||
return totalRows, fmt.Errorf("open %s: %w", rel, err)
|
||||
}
|
||||
if err := importJSONLGzip(ctx, tx, file, table.Name, filter); err != nil {
|
||||
fileProgress := ImportProgress{Phase: "file_start", Table: table.Name, File: rel, FileIndex: index + 1, FileCount: len(files), TotalRows: table.Rows}
|
||||
reportImportProgress(progress, fileProgress)
|
||||
rows, err := importJSONLGzip(ctx, tx, file, table.Name, filter)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
return totalRows, err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return fmt.Errorf("close %s: %w", rel, err)
|
||||
return totalRows, fmt.Errorf("close %s: %w", rel, err)
|
||||
}
|
||||
totalRows += rows
|
||||
fileProgress.Phase = "file_done"
|
||||
fileProgress.Rows = rows
|
||||
reportImportProgress(progress, fileProgress)
|
||||
}
|
||||
return nil
|
||||
return totalRows, nil
|
||||
}
|
||||
|
||||
func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter) error {
|
||||
func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter) (int, error) {
|
||||
gz, err := gzip.NewReader(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open gzip for %s: %w", table, err)
|
||||
return 0, fmt.Errorf("open gzip for %s: %w", table, err)
|
||||
}
|
||||
defer gz.Close()
|
||||
scanner := bufio.NewScanner(gz)
|
||||
scanner.Buffer(make([]byte, 0, 1024*1024), 64*1024*1024)
|
||||
rows := 0
|
||||
for scanner.Scan() {
|
||||
var row map[string]any
|
||||
if err := json.Unmarshal(scanner.Bytes(), &row); err != nil {
|
||||
return fmt.Errorf("decode %s row: %w", table, err)
|
||||
return rows, fmt.Errorf("decode %s row: %w", table, err)
|
||||
}
|
||||
if len(row) == 0 {
|
||||
continue
|
||||
@ -312,20 +335,27 @@ func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table st
|
||||
if filter != nil {
|
||||
keep, err := filter(table, row)
|
||||
if err != nil {
|
||||
return fmt.Errorf("filter %s row: %w", table, err)
|
||||
return rows, fmt.Errorf("filter %s row: %w", table, err)
|
||||
}
|
||||
if !keep {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := insertRow(ctx, tx, table, row); err != nil {
|
||||
return err
|
||||
return rows, err
|
||||
}
|
||||
rows++
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("scan %s rows: %w", table, err)
|
||||
return rows, fmt.Errorf("scan %s rows: %w", table, err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func reportImportProgress(progress func(ImportProgress), event ImportProgress) {
|
||||
if progress != nil {
|
||||
progress(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error {
|
||||
|
||||
@ -145,6 +145,54 @@ func TestImportHooks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportReportsTableAndFileProgress(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
src, err := store.Open(ctx, store.Options{
|
||||
Path: filepath.Join(t.TempDir(), "src.db"),
|
||||
Schema: `create table things(id text primary key, body text not null);`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer src.Close()
|
||||
mustExec(t, src.DB(), `insert into things(id, body) values('one', 'keep')`)
|
||||
mustExec(t, src.DB(), `insert into things(id, body) values('two', 'skip')`)
|
||||
root := t.TempDir()
|
||||
if _, err := Export(ctx, ExportOptions{DB: src.DB(), RootDir: root, Tables: []string{"things"}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dst, err := store.Open(ctx, store.Options{
|
||||
Path: filepath.Join(t.TempDir(), "dst.db"),
|
||||
Schema: `create table things(id text primary key, body text not null);`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dst.Close()
|
||||
var progress []ImportProgress
|
||||
if _, err := Import(ctx, ImportOptions{
|
||||
DB: dst.DB(),
|
||||
RootDir: root,
|
||||
Filter: func(table string, row map[string]any) (bool, error) {
|
||||
return row["id"] != "two", nil
|
||||
},
|
||||
Progress: func(event ImportProgress) {
|
||||
progress = append(progress, event)
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, phase := range []string{"table_start", "file_start", "file_done", "table_done"} {
|
||||
if !containsImportPhase(progress, phase) {
|
||||
t.Fatalf("progress missing %q: %+v", phase, progress)
|
||||
}
|
||||
}
|
||||
if got := progress[len(progress)-1]; got.Phase != "table_done" || got.Table != "things" || got.Rows != 1 || got.TotalRows != 2 {
|
||||
t.Fatalf("table_done progress = %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportLegacySingularFileManifest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
root := t.TempDir()
|
||||
@ -232,6 +280,15 @@ func TestImportFilterSkipsRows(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func containsImportPhase(progress []ImportProgress, phase string) bool {
|
||||
for _, event := range progress {
|
||||
if event.Phase == phase {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mustExec(t *testing.T, db *sql.DB, query string) {
|
||||
t.Helper()
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user