From 413d481c3e22bc9be44ad54c20b894e289d0db85 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Wed, 22 Apr 2026 17:14:20 +0100 Subject: [PATCH] feat: back up embeddings in git snapshots --- CHANGELOG.md | 1 + README.md | 9 + internal/cli/share_commands.go | 22 +++ internal/share/share.go | 304 ++++++++++++++++++++++++++++++--- internal/share/share_test.go | 55 ++++++ 5 files changed, 371 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a06830..dd61d31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ All notable changes to `discrawl` will be documented in this file. - semantic message search now ranks across the full compatible local vector set instead of only the newest candidate window. (#36) Thanks @GaosCode. - hybrid message search now fuses FTS with local semantic vectors while avoiding embedding-provider calls when no local vectors exist. (#37) Thanks @GaosCode. - docs now cover semantic and hybrid search setup, embedding privacy, Git snapshot behavior, and local vector rebuilds. (#39) Thanks @GaosCode. +- Git snapshot publishing can now opt in to backing up generated embedding vectors with `--with-embeddings` while still keeping embedding queue state local. ## 0.3.0 - 2026-04-21 diff --git a/README.md b/README.md index 2e27c6b..46c3b6c 100644 --- a/README.md +++ b/README.md @@ -377,6 +377,15 @@ Hybrid mode is supported too: keep normal Discord credentials configured and set Git snapshots publish archive tables only. Embedding queue state and generated vectors stay local to each machine. Git-only readers can use FTS immediately. To use semantic or hybrid search with semantic recall, configure a local embedding provider and run `discrawl embed --rebuild`. Hybrid search falls back to FTS when no local vectors exist. +If you want to back up generated vectors too, publish them explicitly: + +```bash +discrawl publish --with-embeddings --push +discrawl update --with-embeddings +``` + +`--with-embeddings` exports and imports stored vectors for the configured `[search.embeddings]` provider/model/input version. It never exports `embedding_jobs`. + The Docker smoke test installs `discrawl` in a clean Go container, subscribes to a Git snapshot repo, then checks `search`, `messages`, `sql`, and `report`: ```bash diff --git a/internal/cli/share_commands.go b/internal/cli/share_commands.go index 52bf2d3..cbe2037 100644 --- a/internal/cli/share_commands.go +++ b/internal/cli/share_commands.go @@ -24,6 +24,7 @@ func (r *runtime) runPublish(args []string) error { readmePath := fs.String("readme", "", "") noCommit := fs.Bool("no-commit", false, "") push := fs.Bool("push", false, "") + withEmbeddings := fs.Bool("with-embeddings", false, "") if err := fs.Parse(args); err != nil { return usageErr(err) } @@ -34,6 +35,9 @@ func (r *runtime) runPublish(args []string) error { if err != nil { return err } + if *withEmbeddings { + applyEmbeddingShareOptions(&opts, r.cfg) + } manifest, err := share.Export(r.ctx, r.store, opts) if err != nil { return err @@ -75,6 +79,7 @@ func (r *runtime) runPublish(args []string) error { "remote": opts.Remote, "generated_at": manifest.GeneratedAt, "tables": manifest.Tables, + "embeddings": manifest.Embeddings, "readme": *readmePath, "committed": committed, "pushed": *push, @@ -89,6 +94,7 @@ func (r *runtime) runSubscribe(args []string) error { staleAfter := fs.String("stale-after", "15m", "") noAutoUpdate := fs.Bool("no-auto-update", false, "") noImport := fs.Bool("no-import", false, "") + withEmbeddings := fs.Bool("with-embeddings", false, "") if err := fs.Parse(args); err != nil { return usageErr(err) } @@ -134,6 +140,9 @@ func (r *runtime) runSubscribe(args []string) error { return configErr(err) } opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch} + if *withEmbeddings { + applyEmbeddingShareOptions(&opts, cfg) + } if err := share.Pull(r.ctx, opts); err != nil { return err } @@ -147,6 +156,7 @@ func (r *runtime) runSubscribe(args []string) error { "remote": opts.Remote, "generated_at": manifest.GeneratedAt, "tables": manifest.Tables, + "embeddings": manifest.Embeddings, "imported": imported, }) } @@ -157,6 +167,7 @@ func (r *runtime) runUpdate(args []string) error { repoPath := fs.String("repo", r.cfg.Share.RepoPath, "") remote := fs.String("remote", r.cfg.Share.Remote, "") branch := fs.String("branch", r.cfg.Share.Branch, "") + withEmbeddings := fs.Bool("with-embeddings", false, "") if err := fs.Parse(args); err != nil { return usageErr(err) } @@ -167,6 +178,9 @@ func (r *runtime) runUpdate(args []string) error { if err != nil { return err } + if *withEmbeddings { + applyEmbeddingShareOptions(&opts, r.cfg) + } if err := share.Pull(r.ctx, opts); err != nil { return err } @@ -179,6 +193,7 @@ func (r *runtime) runUpdate(args []string) error { "remote": opts.Remote, "generated_at": manifest.GeneratedAt, "tables": manifest.Tables, + "embeddings": manifest.Embeddings, "imported": imported, }) } @@ -197,6 +212,13 @@ func shareOptionsFromFlags(repoPath, remote, branch string) (share.Options, erro return share.Options{RepoPath: expandedRepo, Remote: remote, Branch: branch}, nil } +func applyEmbeddingShareOptions(opts *share.Options, cfg config.Config) { + opts.IncludeEmbeddings = true + opts.EmbeddingProvider = cfg.Search.Embeddings.Provider + opts.EmbeddingModel = cfg.Search.Embeddings.Model + opts.EmbeddingInputVersion = store.EmbeddingInputVersion +} + func loadConfigOrDefault(path string) (config.Config, error) { cfg, err := config.Load(path) if err == nil { diff --git a/internal/share/share.go b/internal/share/share.go index 83e6adc..df2ee6f 100644 --- a/internal/share/share.go +++ b/internal/share/share.go @@ -4,6 +4,7 @@ import ( "compress/gzip" "context" "database/sql" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -42,16 +43,21 @@ var SnapshotTables = []string{ } type Options struct { - RepoPath string - Remote string - Branch string + RepoPath string + Remote string + Branch string + IncludeEmbeddings bool + EmbeddingProvider string + EmbeddingModel string + EmbeddingInputVersion string } type Manifest struct { - Version int `json:"version"` - GeneratedAt time.Time `json:"generated_at"` - Tables []TableManifest `json:"tables"` - Files map[string]string `json:"files,omitempty"` + Version int `json:"version"` + GeneratedAt time.Time `json:"generated_at"` + Tables []TableManifest `json:"tables"` + Embeddings []EmbeddingManifest `json:"embeddings,omitempty"` + Files map[string]string `json:"files,omitempty"` } type TableManifest struct { @@ -62,6 +68,15 @@ type TableManifest struct { Rows int `json:"rows"` } +type EmbeddingManifest struct { + Provider string `json:"provider"` + Model string `json:"model"` + InputVersion string `json:"input_version"` + Files []string `json:"files"` + Columns []string `json:"columns"` + Rows int `json:"rows"` +} + func EnsureRepo(ctx context.Context, opts Options) error { if strings.TrimSpace(opts.RepoPath) == "" { return fmt.Errorf("share repo path is empty") @@ -163,11 +178,10 @@ func Export(ctx context.Context, s *store.Store, opts Options) (Manifest, error) if err := EnsureRepo(ctx, opts); err != nil { return Manifest{}, err } - dataDir := filepath.Join(opts.RepoPath, "tables") - if err := os.RemoveAll(dataDir); err != nil { + if err := os.RemoveAll(filepath.Join(opts.RepoPath, "tables")); err != nil { return Manifest{}, fmt.Errorf("reset tables dir: %w", err) } - if err := os.MkdirAll(dataDir, 0o755); err != nil { + if err := os.MkdirAll(filepath.Join(opts.RepoPath, "tables"), 0o755); err != nil { return Manifest{}, fmt.Errorf("mkdir tables dir: %w", err) } manifest := Manifest{ @@ -175,13 +189,25 @@ func Export(ctx context.Context, s *store.Store, opts Options) (Manifest, error) GeneratedAt: time.Now().UTC(), Files: map[string]string{"manifest": ManifestName}, } + if !opts.IncludeEmbeddings { + if previous, err := ReadManifest(opts.RepoPath); err == nil { + manifest.Embeddings = previous.Embeddings + } + } for _, table := range SnapshotTables { - entry, err := exportTable(ctx, s.DB(), dataDir, table) + entry, err := exportTable(ctx, s.DB(), opts.RepoPath, table) if err != nil { return Manifest{}, err } manifest.Tables = append(manifest.Tables, entry) } + if opts.IncludeEmbeddings { + entry, err := exportEmbeddings(ctx, s.DB(), opts) + if err != nil { + return Manifest{}, err + } + manifest.Embeddings = []EmbeddingManifest{entry} + } body, err := json.MarshalIndent(manifest, "", " ") if err != nil { return Manifest{}, err @@ -234,6 +260,11 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) return Manifest{}, err } } + if opts.IncludeEmbeddings { + if err := importEmbeddings(ctx, tx, opts, manifest.Embeddings); err != nil { + return Manifest{}, err + } + } if err := tx.Commit(); err != nil { return Manifest{}, err } @@ -283,6 +314,11 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes return Manifest{}, false, err } if ManifestAlreadyImported(ctx, s, manifest) { + if opts.IncludeEmbeddings { + if err := ImportEmbeddings(ctx, s, opts, manifest); err != nil { + return Manifest{}, false, err + } + } if err := MarkImported(ctx, s, manifest); err != nil { return Manifest{}, false, err } @@ -295,6 +331,27 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes return imported, true, nil } +func ImportEmbeddings(ctx context.Context, s *store.Store, opts Options, manifest Manifest) error { + tx, err := s.DB().BeginTx(ctx, nil) + if err != nil { + return err + } + committed := false + defer func() { + if !committed { + _ = tx.Rollback() + } + }() + if err := importEmbeddings(ctx, tx, opts, manifest.Embeddings); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return err + } + committed = true + return nil +} + func ManifestAlreadyImported(ctx context.Context, s *store.Store, manifest Manifest) bool { if manifest.GeneratedAt.IsZero() { return false @@ -353,7 +410,7 @@ func NeedsImport(ctx context.Context, s *store.Store, staleAfter time.Duration) return time.Since(t) >= staleAfter } -func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableManifest, error) { +func exportTable(ctx context.Context, db *sql.DB, repoPath, table string) (TableManifest, error) { query := "select * from " + table rows, err := db.QueryContext(ctx, query) if err != nil { @@ -364,11 +421,11 @@ func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableM if err != nil { return TableManifest{}, fmt.Errorf("columns %s: %w", table, err) } - tableDir := filepath.Join(dataDir, table) + tableDir := filepath.Join(repoPath, "tables", table) if err := os.MkdirAll(tableDir, 0o755); err != nil { return TableManifest{}, fmt.Errorf("mkdir %s: %w", table, err) } - writer := tableShardWriter{dataDir: dataDir, table: table} + writer := tableShardWriter{rootDir: repoPath, relDir: filepath.ToSlash(filepath.Join("tables", table)), label: table} if err := writer.open(); err != nil { return TableManifest{}, err } @@ -415,6 +472,95 @@ func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableM return TableManifest{Name: table, Files: writer.files, Columns: columns, Rows: count}, nil } +func exportEmbeddings(ctx context.Context, db *sql.DB, opts Options) (EmbeddingManifest, error) { + provider := strings.ToLower(strings.TrimSpace(opts.EmbeddingProvider)) + model := strings.TrimSpace(opts.EmbeddingModel) + inputVersion := strings.TrimSpace(opts.EmbeddingInputVersion) + if inputVersion == "" { + inputVersion = store.EmbeddingInputVersion + } + if provider == "" || model == "" { + return EmbeddingManifest{}, fmt.Errorf("embedding provider and model are required") + } + relDir := filepath.ToSlash(filepath.Join("embeddings", safePathSegment(provider), safePathSegment(model), safePathSegment(inputVersion))) + if err := os.RemoveAll(filepath.Join(opts.RepoPath, "embeddings")); err != nil { + return EmbeddingManifest{}, fmt.Errorf("reset embeddings dir: %w", err) + } + if err := os.MkdirAll(filepath.Join(opts.RepoPath, filepath.FromSlash(relDir)), 0o755); err != nil { + return EmbeddingManifest{}, fmt.Errorf("mkdir %s: %w", relDir, err) + } + rows, err := db.QueryContext(ctx, ` + select message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at + from message_embeddings + where provider = ? and model = ? and input_version = ? + order by message_id + `, provider, model, inputVersion) + if err != nil { + return EmbeddingManifest{}, fmt.Errorf("query message_embeddings: %w", err) + } + defer func() { _ = rows.Close() }() + writer := tableShardWriter{rootDir: opts.RepoPath, relDir: relDir, label: "message_embeddings"} + if err := writer.open(); err != nil { + return EmbeddingManifest{}, err + } + defer func() { _ = writer.close() }() + columns := []string{"message_id", "provider", "model", "input_version", "dimensions", "embedding_blob", "embedded_at"} + count := 0 + for rows.Next() { + var ( + messageID string + rowProv string + rowModel string + rowInput string + dimensions int + blob []byte + embeddedAt string + ) + if err := rows.Scan(&messageID, &rowProv, &rowModel, &rowInput, &dimensions, &blob, &embeddedAt); err != nil { + return EmbeddingManifest{}, fmt.Errorf("scan message_embeddings: %w", err) + } + body, err := json.Marshal(map[string]any{ + "message_id": messageID, + "provider": rowProv, + "model": rowModel, + "input_version": rowInput, + "dimensions": dimensions, + "embedding_blob": base64.StdEncoding.EncodeToString(blob), + "embedded_at": embeddedAt, + }) + if err != nil { + return EmbeddingManifest{}, fmt.Errorf("marshal message_embeddings row: %w", err) + } + if err := writer.rotateIfNeeded(); err != nil { + return EmbeddingManifest{}, err + } + if _, err := writer.Write(body); err != nil { + return EmbeddingManifest{}, fmt.Errorf("write message_embeddings row: %w", err) + } + if _, err := writer.Write([]byte{'\n'}); err != nil { + return EmbeddingManifest{}, fmt.Errorf("write message_embeddings newline: %w", err) + } + count++ + if err := writer.finishRow(); err != nil { + return EmbeddingManifest{}, err + } + } + if err := rows.Err(); err != nil { + return EmbeddingManifest{}, fmt.Errorf("iterate message_embeddings: %w", err) + } + if err := writer.close(); err != nil { + return EmbeddingManifest{}, err + } + return EmbeddingManifest{ + Provider: provider, + Model: model, + InputVersion: inputVersion, + Files: writer.files, + Columns: columns, + Rows: count, + }, nil +} + func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableManifest) error { files := table.Files if len(files) == 0 && strings.TrimSpace(table.File) != "" { @@ -470,9 +616,104 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table return nil } +func importEmbeddings(ctx context.Context, tx *sql.Tx, opts Options, manifests []EmbeddingManifest) error { + if len(manifests) == 0 { + return nil + } + stmt, err := tx.PrepareContext(ctx, ` + insert into message_embeddings( + message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at + ) values(?, ?, ?, ?, ?, ?, ?) + on conflict(message_id, provider, model, input_version) do update set + dimensions = excluded.dimensions, + embedding_blob = excluded.embedding_blob, + embedded_at = excluded.embedded_at + `) + if err != nil { + return fmt.Errorf("prepare import message_embeddings: %w", err) + } + defer func() { _ = stmt.Close() }() + for _, manifest := range manifests { + if !embeddingManifestMatches(opts, manifest) { + continue + } + files := manifest.Files + if len(files) == 0 { + return fmt.Errorf("embedding manifest %s/%s/%s has no files", manifest.Provider, manifest.Model, manifest.InputVersion) + } + for _, rel := range files { + if err := importEmbeddingFile(ctx, stmt, opts.RepoPath, rel); err != nil { + return err + } + } + } + return nil +} + +func importEmbeddingFile(ctx context.Context, stmt *sql.Stmt, repoPath, rel string) error { + path := filepath.Join(repoPath, filepath.FromSlash(rel)) + file, err := os.Open(path) + if err != nil { + return fmt.Errorf("open %s: %w", rel, err) + } + defer func() { _ = file.Close() }() + gz, err := gzip.NewReader(file) + if err != nil { + return fmt.Errorf("read gzip %s: %w", rel, err) + } + defer func() { _ = gz.Close() }() + dec := json.NewDecoder(gz) + dec.UseNumber() + for { + var row struct { + MessageID string `json:"message_id"` + Provider string `json:"provider"` + Model string `json:"model"` + InputVersion string `json:"input_version"` + Dimensions json.Number `json:"dimensions"` + EmbeddingBlob string `json:"embedding_blob"` + EmbeddedAt string `json:"embedded_at"` + } + err := dec.Decode(&row) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("decode %s: %w", rel, err) + } + dimensions, err := strconv.Atoi(row.Dimensions.String()) + if err != nil { + return fmt.Errorf("decode dimensions in %s: %w", rel, err) + } + blob, err := base64.StdEncoding.DecodeString(row.EmbeddingBlob) + if err != nil { + return fmt.Errorf("decode embedding blob in %s: %w", rel, err) + } + if _, err := stmt.ExecContext(ctx, row.MessageID, row.Provider, row.Model, row.InputVersion, dimensions, blob, row.EmbeddedAt); err != nil { + return fmt.Errorf("insert message_embeddings: %w", err) + } + } + return nil +} + +func embeddingManifestMatches(opts Options, manifest EmbeddingManifest) bool { + if strings.TrimSpace(opts.EmbeddingProvider) != "" && manifest.Provider != strings.ToLower(strings.TrimSpace(opts.EmbeddingProvider)) { + return false + } + if strings.TrimSpace(opts.EmbeddingModel) != "" && manifest.Model != strings.TrimSpace(opts.EmbeddingModel) { + return false + } + inputVersion := strings.TrimSpace(opts.EmbeddingInputVersion) + if inputVersion == "" { + inputVersion = store.EmbeddingInputVersion + } + return manifest.InputVersion == inputVersion +} + type tableShardWriter struct { - dataDir string - table string + rootDir string + relDir string + label string nextShard int rowsInShard int files []string @@ -482,8 +723,8 @@ type tableShardWriter struct { } func (w *tableShardWriter) open() error { - rel := filepath.ToSlash(filepath.Join("tables", w.table, fmt.Sprintf("%06d.jsonl.gz", w.nextShard))) - path := filepath.Join(w.dataDir, w.table, fmt.Sprintf("%06d.jsonl.gz", w.nextShard)) + rel := filepath.ToSlash(filepath.Join(w.relDir, fmt.Sprintf("%06d.jsonl.gz", w.nextShard))) + path := filepath.Join(w.rootDir, filepath.FromSlash(rel)) file, err := os.Create(path) if err != nil { return fmt.Errorf("create %s: %w", rel, err) @@ -517,7 +758,7 @@ func (w *tableShardWriter) finishRow() error { return nil } if err := w.gz.Flush(); err != nil { - return fmt.Errorf("flush %s shard: %w", w.table, err) + return fmt.Errorf("flush %s shard: %w", w.label, err) } return nil } @@ -537,7 +778,7 @@ func (w *tableShardWriter) close() error { w.file = nil } if closeErr != nil { - return fmt.Errorf("close %s shard: %w", w.table, closeErr) + return fmt.Errorf("close %s shard: %w", w.label, closeErr) } return nil } @@ -591,6 +832,29 @@ func quoteIdent(s string) string { return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } +func safePathSegment(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "_" + } + var b strings.Builder + for _, r := range s { + switch { + case r >= 'a' && r <= 'z': + b.WriteRune(r) + case r >= 'A' && r <= 'Z': + b.WriteRune(r) + case r >= '0' && r <= '9': + b.WriteRune(r) + case r == '-' || r == '_' || r == '.': + b.WriteRune(r) + default: + b.WriteRune('_') + } + } + return b.String() +} + func run(ctx context.Context, dir, name string, args ...string) error { out, err := output(ctx, dir, name, args...) if err != nil { diff --git a/internal/share/share_test.go b/internal/share/share_test.go index d441caf..674f01c 100644 --- a/internal/share/share_test.go +++ b/internal/share/share_test.go @@ -75,6 +75,7 @@ func TestSnapshotExcludesLocalEmbeddingState(t *testing.T) { require.NoError(t, err) require.NotContains(t, tableNames(manifest), "embedding_jobs") require.NotContains(t, tableNames(manifest), "message_embeddings") + require.Empty(t, manifest.Embeddings) dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db")) require.NoError(t, err) @@ -94,6 +95,60 @@ func TestSnapshotExcludesLocalEmbeddingState(t *testing.T) { require.Equal(t, "pending", state) } +func TestExportImportEmbeddingsOptIn(t *testing.T) { + ctx := context.Background() + src := seedStore(t, filepath.Join(t.TempDir(), "src.db")) + defer func() { _ = src.Close() }() + + vector := []float32{1, 0.5} + blob, err := store.EncodeEmbeddingVector(vector) + require.NoError(t, err) + embeddedAt := time.Now().UTC().Format(time.RFC3339Nano) + _, err = src.DB().ExecContext(ctx, ` + insert into message_embeddings( + message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at + ) values ('m1', 'openai', 'text-embedding-3-small', ?, 2, ?, ?) + `, store.EmbeddingInputVersion, blob, embeddedAt) + require.NoError(t, err) + + repo := filepath.Join(t.TempDir(), "share") + opts := Options{ + RepoPath: repo, + Branch: "main", + IncludeEmbeddings: true, + EmbeddingProvider: "openai", + EmbeddingModel: "text-embedding-3-small", + EmbeddingInputVersion: store.EmbeddingInputVersion, + } + manifest, err := Export(ctx, src, opts) + require.NoError(t, err) + require.Len(t, manifest.Embeddings, 1) + require.Equal(t, 1, manifest.Embeddings[0].Rows) + require.NotEmpty(t, manifest.Embeddings[0].Files) + require.FileExists(t, filepath.Join(repo, filepath.FromSlash(manifest.Embeddings[0].Files[0]))) + + dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db")) + require.NoError(t, err) + defer func() { _ = dst.Close() }() + _, err = Import(ctx, dst, opts) + require.NoError(t, err) + + var gotBlob []byte + var gotDimensions int + require.NoError(t, dst.DB().QueryRowContext(ctx, ` + select dimensions, embedding_blob + from message_embeddings + where message_id = 'm1' + and provider = 'openai' + and model = 'text-embedding-3-small' + and input_version = ? + `, store.EmbeddingInputVersion).Scan(&gotDimensions, &gotBlob)) + require.Equal(t, 2, gotDimensions) + gotVector, err := store.DecodeEmbeddingVector(gotBlob) + require.NoError(t, err) + require.Equal(t, vector, gotVector) +} + func TestImportIfChangedSkipsSameManifest(t *testing.T) { ctx := context.Background() src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))