feat: back up embeddings in git snapshots

This commit is contained in:
Peter Steinberger 2026-04-22 17:14:20 +01:00
parent 4437514537
commit 413d481c3e
No known key found for this signature in database
5 changed files with 371 additions and 20 deletions

View File

@ -13,6 +13,7 @@ All notable changes to `discrawl` will be documented in this file.
- semantic message search now ranks across the full compatible local vector set instead of only the newest candidate window. (#36) Thanks @GaosCode.
- hybrid message search now fuses FTS with local semantic vectors while avoiding embedding-provider calls when no local vectors exist. (#37) Thanks @GaosCode.
- docs now cover semantic and hybrid search setup, embedding privacy, Git snapshot behavior, and local vector rebuilds. (#39) Thanks @GaosCode.
- Git snapshot publishing can now opt in to backing up generated embedding vectors with `--with-embeddings` while still keeping embedding queue state local.
## 0.3.0 - 2026-04-21

View File

@ -377,6 +377,15 @@ Hybrid mode is supported too: keep normal Discord credentials configured and set
Git snapshots publish archive tables only. Embedding queue state and generated vectors stay local to each machine. Git-only readers can use FTS immediately. To use semantic or hybrid search with semantic recall, configure a local embedding provider and run `discrawl embed --rebuild`. Hybrid search falls back to FTS when no local vectors exist.
If you want to back up generated vectors too, publish them explicitly:
```bash
discrawl publish --with-embeddings --push
discrawl update --with-embeddings
```
`--with-embeddings` exports and imports stored vectors for the configured `[search.embeddings]` provider/model/input version. It never exports `embedding_jobs`.
The Docker smoke test installs `discrawl` in a clean Go container, subscribes to a Git snapshot repo, then checks `search`, `messages`, `sql`, and `report`:
```bash

View File

@ -24,6 +24,7 @@ func (r *runtime) runPublish(args []string) error {
readmePath := fs.String("readme", "", "")
noCommit := fs.Bool("no-commit", false, "")
push := fs.Bool("push", false, "")
withEmbeddings := fs.Bool("with-embeddings", false, "")
if err := fs.Parse(args); err != nil {
return usageErr(err)
}
@ -34,6 +35,9 @@ func (r *runtime) runPublish(args []string) error {
if err != nil {
return err
}
if *withEmbeddings {
applyEmbeddingShareOptions(&opts, r.cfg)
}
manifest, err := share.Export(r.ctx, r.store, opts)
if err != nil {
return err
@ -75,6 +79,7 @@ func (r *runtime) runPublish(args []string) error {
"remote": opts.Remote,
"generated_at": manifest.GeneratedAt,
"tables": manifest.Tables,
"embeddings": manifest.Embeddings,
"readme": *readmePath,
"committed": committed,
"pushed": *push,
@ -89,6 +94,7 @@ func (r *runtime) runSubscribe(args []string) error {
staleAfter := fs.String("stale-after", "15m", "")
noAutoUpdate := fs.Bool("no-auto-update", false, "")
noImport := fs.Bool("no-import", false, "")
withEmbeddings := fs.Bool("with-embeddings", false, "")
if err := fs.Parse(args); err != nil {
return usageErr(err)
}
@ -134,6 +140,9 @@ func (r *runtime) runSubscribe(args []string) error {
return configErr(err)
}
opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch}
if *withEmbeddings {
applyEmbeddingShareOptions(&opts, cfg)
}
if err := share.Pull(r.ctx, opts); err != nil {
return err
}
@ -147,6 +156,7 @@ func (r *runtime) runSubscribe(args []string) error {
"remote": opts.Remote,
"generated_at": manifest.GeneratedAt,
"tables": manifest.Tables,
"embeddings": manifest.Embeddings,
"imported": imported,
})
}
@ -157,6 +167,7 @@ func (r *runtime) runUpdate(args []string) error {
repoPath := fs.String("repo", r.cfg.Share.RepoPath, "")
remote := fs.String("remote", r.cfg.Share.Remote, "")
branch := fs.String("branch", r.cfg.Share.Branch, "")
withEmbeddings := fs.Bool("with-embeddings", false, "")
if err := fs.Parse(args); err != nil {
return usageErr(err)
}
@ -167,6 +178,9 @@ func (r *runtime) runUpdate(args []string) error {
if err != nil {
return err
}
if *withEmbeddings {
applyEmbeddingShareOptions(&opts, r.cfg)
}
if err := share.Pull(r.ctx, opts); err != nil {
return err
}
@ -179,6 +193,7 @@ func (r *runtime) runUpdate(args []string) error {
"remote": opts.Remote,
"generated_at": manifest.GeneratedAt,
"tables": manifest.Tables,
"embeddings": manifest.Embeddings,
"imported": imported,
})
}
@ -197,6 +212,13 @@ func shareOptionsFromFlags(repoPath, remote, branch string) (share.Options, erro
return share.Options{RepoPath: expandedRepo, Remote: remote, Branch: branch}, nil
}
func applyEmbeddingShareOptions(opts *share.Options, cfg config.Config) {
opts.IncludeEmbeddings = true
opts.EmbeddingProvider = cfg.Search.Embeddings.Provider
opts.EmbeddingModel = cfg.Search.Embeddings.Model
opts.EmbeddingInputVersion = store.EmbeddingInputVersion
}
func loadConfigOrDefault(path string) (config.Config, error) {
cfg, err := config.Load(path)
if err == nil {

View File

@ -4,6 +4,7 @@ import (
"compress/gzip"
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -42,16 +43,21 @@ var SnapshotTables = []string{
}
type Options struct {
RepoPath string
Remote string
Branch string
RepoPath string
Remote string
Branch string
IncludeEmbeddings bool
EmbeddingProvider string
EmbeddingModel string
EmbeddingInputVersion string
}
type Manifest struct {
Version int `json:"version"`
GeneratedAt time.Time `json:"generated_at"`
Tables []TableManifest `json:"tables"`
Files map[string]string `json:"files,omitempty"`
Version int `json:"version"`
GeneratedAt time.Time `json:"generated_at"`
Tables []TableManifest `json:"tables"`
Embeddings []EmbeddingManifest `json:"embeddings,omitempty"`
Files map[string]string `json:"files,omitempty"`
}
type TableManifest struct {
@ -62,6 +68,15 @@ type TableManifest struct {
Rows int `json:"rows"`
}
type EmbeddingManifest struct {
Provider string `json:"provider"`
Model string `json:"model"`
InputVersion string `json:"input_version"`
Files []string `json:"files"`
Columns []string `json:"columns"`
Rows int `json:"rows"`
}
func EnsureRepo(ctx context.Context, opts Options) error {
if strings.TrimSpace(opts.RepoPath) == "" {
return fmt.Errorf("share repo path is empty")
@ -163,11 +178,10 @@ func Export(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
if err := EnsureRepo(ctx, opts); err != nil {
return Manifest{}, err
}
dataDir := filepath.Join(opts.RepoPath, "tables")
if err := os.RemoveAll(dataDir); err != nil {
if err := os.RemoveAll(filepath.Join(opts.RepoPath, "tables")); err != nil {
return Manifest{}, fmt.Errorf("reset tables dir: %w", err)
}
if err := os.MkdirAll(dataDir, 0o755); err != nil {
if err := os.MkdirAll(filepath.Join(opts.RepoPath, "tables"), 0o755); err != nil {
return Manifest{}, fmt.Errorf("mkdir tables dir: %w", err)
}
manifest := Manifest{
@ -175,13 +189,25 @@ func Export(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
GeneratedAt: time.Now().UTC(),
Files: map[string]string{"manifest": ManifestName},
}
if !opts.IncludeEmbeddings {
if previous, err := ReadManifest(opts.RepoPath); err == nil {
manifest.Embeddings = previous.Embeddings
}
}
for _, table := range SnapshotTables {
entry, err := exportTable(ctx, s.DB(), dataDir, table)
entry, err := exportTable(ctx, s.DB(), opts.RepoPath, table)
if err != nil {
return Manifest{}, err
}
manifest.Tables = append(manifest.Tables, entry)
}
if opts.IncludeEmbeddings {
entry, err := exportEmbeddings(ctx, s.DB(), opts)
if err != nil {
return Manifest{}, err
}
manifest.Embeddings = []EmbeddingManifest{entry}
}
body, err := json.MarshalIndent(manifest, "", " ")
if err != nil {
return Manifest{}, err
@ -234,6 +260,11 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
return Manifest{}, err
}
}
if opts.IncludeEmbeddings {
if err := importEmbeddings(ctx, tx, opts, manifest.Embeddings); err != nil {
return Manifest{}, err
}
}
if err := tx.Commit(); err != nil {
return Manifest{}, err
}
@ -283,6 +314,11 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
return Manifest{}, false, err
}
if ManifestAlreadyImported(ctx, s, manifest) {
if opts.IncludeEmbeddings {
if err := ImportEmbeddings(ctx, s, opts, manifest); err != nil {
return Manifest{}, false, err
}
}
if err := MarkImported(ctx, s, manifest); err != nil {
return Manifest{}, false, err
}
@ -295,6 +331,27 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
return imported, true, nil
}
func ImportEmbeddings(ctx context.Context, s *store.Store, opts Options, manifest Manifest) error {
tx, err := s.DB().BeginTx(ctx, nil)
if err != nil {
return err
}
committed := false
defer func() {
if !committed {
_ = tx.Rollback()
}
}()
if err := importEmbeddings(ctx, tx, opts, manifest.Embeddings); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
committed = true
return nil
}
func ManifestAlreadyImported(ctx context.Context, s *store.Store, manifest Manifest) bool {
if manifest.GeneratedAt.IsZero() {
return false
@ -353,7 +410,7 @@ func NeedsImport(ctx context.Context, s *store.Store, staleAfter time.Duration)
return time.Since(t) >= staleAfter
}
func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableManifest, error) {
func exportTable(ctx context.Context, db *sql.DB, repoPath, table string) (TableManifest, error) {
query := "select * from " + table
rows, err := db.QueryContext(ctx, query)
if err != nil {
@ -364,11 +421,11 @@ func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableM
if err != nil {
return TableManifest{}, fmt.Errorf("columns %s: %w", table, err)
}
tableDir := filepath.Join(dataDir, table)
tableDir := filepath.Join(repoPath, "tables", table)
if err := os.MkdirAll(tableDir, 0o755); err != nil {
return TableManifest{}, fmt.Errorf("mkdir %s: %w", table, err)
}
writer := tableShardWriter{dataDir: dataDir, table: table}
writer := tableShardWriter{rootDir: repoPath, relDir: filepath.ToSlash(filepath.Join("tables", table)), label: table}
if err := writer.open(); err != nil {
return TableManifest{}, err
}
@ -415,6 +472,95 @@ func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableM
return TableManifest{Name: table, Files: writer.files, Columns: columns, Rows: count}, nil
}
func exportEmbeddings(ctx context.Context, db *sql.DB, opts Options) (EmbeddingManifest, error) {
provider := strings.ToLower(strings.TrimSpace(opts.EmbeddingProvider))
model := strings.TrimSpace(opts.EmbeddingModel)
inputVersion := strings.TrimSpace(opts.EmbeddingInputVersion)
if inputVersion == "" {
inputVersion = store.EmbeddingInputVersion
}
if provider == "" || model == "" {
return EmbeddingManifest{}, fmt.Errorf("embedding provider and model are required")
}
relDir := filepath.ToSlash(filepath.Join("embeddings", safePathSegment(provider), safePathSegment(model), safePathSegment(inputVersion)))
if err := os.RemoveAll(filepath.Join(opts.RepoPath, "embeddings")); err != nil {
return EmbeddingManifest{}, fmt.Errorf("reset embeddings dir: %w", err)
}
if err := os.MkdirAll(filepath.Join(opts.RepoPath, filepath.FromSlash(relDir)), 0o755); err != nil {
return EmbeddingManifest{}, fmt.Errorf("mkdir %s: %w", relDir, err)
}
rows, err := db.QueryContext(ctx, `
select message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
from message_embeddings
where provider = ? and model = ? and input_version = ?
order by message_id
`, provider, model, inputVersion)
if err != nil {
return EmbeddingManifest{}, fmt.Errorf("query message_embeddings: %w", err)
}
defer func() { _ = rows.Close() }()
writer := tableShardWriter{rootDir: opts.RepoPath, relDir: relDir, label: "message_embeddings"}
if err := writer.open(); err != nil {
return EmbeddingManifest{}, err
}
defer func() { _ = writer.close() }()
columns := []string{"message_id", "provider", "model", "input_version", "dimensions", "embedding_blob", "embedded_at"}
count := 0
for rows.Next() {
var (
messageID string
rowProv string
rowModel string
rowInput string
dimensions int
blob []byte
embeddedAt string
)
if err := rows.Scan(&messageID, &rowProv, &rowModel, &rowInput, &dimensions, &blob, &embeddedAt); err != nil {
return EmbeddingManifest{}, fmt.Errorf("scan message_embeddings: %w", err)
}
body, err := json.Marshal(map[string]any{
"message_id": messageID,
"provider": rowProv,
"model": rowModel,
"input_version": rowInput,
"dimensions": dimensions,
"embedding_blob": base64.StdEncoding.EncodeToString(blob),
"embedded_at": embeddedAt,
})
if err != nil {
return EmbeddingManifest{}, fmt.Errorf("marshal message_embeddings row: %w", err)
}
if err := writer.rotateIfNeeded(); err != nil {
return EmbeddingManifest{}, err
}
if _, err := writer.Write(body); err != nil {
return EmbeddingManifest{}, fmt.Errorf("write message_embeddings row: %w", err)
}
if _, err := writer.Write([]byte{'\n'}); err != nil {
return EmbeddingManifest{}, fmt.Errorf("write message_embeddings newline: %w", err)
}
count++
if err := writer.finishRow(); err != nil {
return EmbeddingManifest{}, err
}
}
if err := rows.Err(); err != nil {
return EmbeddingManifest{}, fmt.Errorf("iterate message_embeddings: %w", err)
}
if err := writer.close(); err != nil {
return EmbeddingManifest{}, err
}
return EmbeddingManifest{
Provider: provider,
Model: model,
InputVersion: inputVersion,
Files: writer.files,
Columns: columns,
Rows: count,
}, nil
}
func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableManifest) error {
files := table.Files
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
@ -470,9 +616,104 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
return nil
}
func importEmbeddings(ctx context.Context, tx *sql.Tx, opts Options, manifests []EmbeddingManifest) error {
if len(manifests) == 0 {
return nil
}
stmt, err := tx.PrepareContext(ctx, `
insert into message_embeddings(
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
) values(?, ?, ?, ?, ?, ?, ?)
on conflict(message_id, provider, model, input_version) do update set
dimensions = excluded.dimensions,
embedding_blob = excluded.embedding_blob,
embedded_at = excluded.embedded_at
`)
if err != nil {
return fmt.Errorf("prepare import message_embeddings: %w", err)
}
defer func() { _ = stmt.Close() }()
for _, manifest := range manifests {
if !embeddingManifestMatches(opts, manifest) {
continue
}
files := manifest.Files
if len(files) == 0 {
return fmt.Errorf("embedding manifest %s/%s/%s has no files", manifest.Provider, manifest.Model, manifest.InputVersion)
}
for _, rel := range files {
if err := importEmbeddingFile(ctx, stmt, opts.RepoPath, rel); err != nil {
return err
}
}
}
return nil
}
func importEmbeddingFile(ctx context.Context, stmt *sql.Stmt, repoPath, rel string) error {
path := filepath.Join(repoPath, filepath.FromSlash(rel))
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("open %s: %w", rel, err)
}
defer func() { _ = file.Close() }()
gz, err := gzip.NewReader(file)
if err != nil {
return fmt.Errorf("read gzip %s: %w", rel, err)
}
defer func() { _ = gz.Close() }()
dec := json.NewDecoder(gz)
dec.UseNumber()
for {
var row struct {
MessageID string `json:"message_id"`
Provider string `json:"provider"`
Model string `json:"model"`
InputVersion string `json:"input_version"`
Dimensions json.Number `json:"dimensions"`
EmbeddingBlob string `json:"embedding_blob"`
EmbeddedAt string `json:"embedded_at"`
}
err := dec.Decode(&row)
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return fmt.Errorf("decode %s: %w", rel, err)
}
dimensions, err := strconv.Atoi(row.Dimensions.String())
if err != nil {
return fmt.Errorf("decode dimensions in %s: %w", rel, err)
}
blob, err := base64.StdEncoding.DecodeString(row.EmbeddingBlob)
if err != nil {
return fmt.Errorf("decode embedding blob in %s: %w", rel, err)
}
if _, err := stmt.ExecContext(ctx, row.MessageID, row.Provider, row.Model, row.InputVersion, dimensions, blob, row.EmbeddedAt); err != nil {
return fmt.Errorf("insert message_embeddings: %w", err)
}
}
return nil
}
func embeddingManifestMatches(opts Options, manifest EmbeddingManifest) bool {
if strings.TrimSpace(opts.EmbeddingProvider) != "" && manifest.Provider != strings.ToLower(strings.TrimSpace(opts.EmbeddingProvider)) {
return false
}
if strings.TrimSpace(opts.EmbeddingModel) != "" && manifest.Model != strings.TrimSpace(opts.EmbeddingModel) {
return false
}
inputVersion := strings.TrimSpace(opts.EmbeddingInputVersion)
if inputVersion == "" {
inputVersion = store.EmbeddingInputVersion
}
return manifest.InputVersion == inputVersion
}
type tableShardWriter struct {
dataDir string
table string
rootDir string
relDir string
label string
nextShard int
rowsInShard int
files []string
@ -482,8 +723,8 @@ type tableShardWriter struct {
}
func (w *tableShardWriter) open() error {
rel := filepath.ToSlash(filepath.Join("tables", w.table, fmt.Sprintf("%06d.jsonl.gz", w.nextShard)))
path := filepath.Join(w.dataDir, w.table, fmt.Sprintf("%06d.jsonl.gz", w.nextShard))
rel := filepath.ToSlash(filepath.Join(w.relDir, fmt.Sprintf("%06d.jsonl.gz", w.nextShard)))
path := filepath.Join(w.rootDir, filepath.FromSlash(rel))
file, err := os.Create(path)
if err != nil {
return fmt.Errorf("create %s: %w", rel, err)
@ -517,7 +758,7 @@ func (w *tableShardWriter) finishRow() error {
return nil
}
if err := w.gz.Flush(); err != nil {
return fmt.Errorf("flush %s shard: %w", w.table, err)
return fmt.Errorf("flush %s shard: %w", w.label, err)
}
return nil
}
@ -537,7 +778,7 @@ func (w *tableShardWriter) close() error {
w.file = nil
}
if closeErr != nil {
return fmt.Errorf("close %s shard: %w", w.table, closeErr)
return fmt.Errorf("close %s shard: %w", w.label, closeErr)
}
return nil
}
@ -591,6 +832,29 @@ func quoteIdent(s string) string {
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
func safePathSegment(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return "_"
}
var b strings.Builder
for _, r := range s {
switch {
case r >= 'a' && r <= 'z':
b.WriteRune(r)
case r >= 'A' && r <= 'Z':
b.WriteRune(r)
case r >= '0' && r <= '9':
b.WriteRune(r)
case r == '-' || r == '_' || r == '.':
b.WriteRune(r)
default:
b.WriteRune('_')
}
}
return b.String()
}
func run(ctx context.Context, dir, name string, args ...string) error {
out, err := output(ctx, dir, name, args...)
if err != nil {

View File

@ -75,6 +75,7 @@ func TestSnapshotExcludesLocalEmbeddingState(t *testing.T) {
require.NoError(t, err)
require.NotContains(t, tableNames(manifest), "embedding_jobs")
require.NotContains(t, tableNames(manifest), "message_embeddings")
require.Empty(t, manifest.Embeddings)
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
require.NoError(t, err)
@ -94,6 +95,60 @@ func TestSnapshotExcludesLocalEmbeddingState(t *testing.T) {
require.Equal(t, "pending", state)
}
func TestExportImportEmbeddingsOptIn(t *testing.T) {
ctx := context.Background()
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
defer func() { _ = src.Close() }()
vector := []float32{1, 0.5}
blob, err := store.EncodeEmbeddingVector(vector)
require.NoError(t, err)
embeddedAt := time.Now().UTC().Format(time.RFC3339Nano)
_, err = src.DB().ExecContext(ctx, `
insert into message_embeddings(
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
) values ('m1', 'openai', 'text-embedding-3-small', ?, 2, ?, ?)
`, store.EmbeddingInputVersion, blob, embeddedAt)
require.NoError(t, err)
repo := filepath.Join(t.TempDir(), "share")
opts := Options{
RepoPath: repo,
Branch: "main",
IncludeEmbeddings: true,
EmbeddingProvider: "openai",
EmbeddingModel: "text-embedding-3-small",
EmbeddingInputVersion: store.EmbeddingInputVersion,
}
manifest, err := Export(ctx, src, opts)
require.NoError(t, err)
require.Len(t, manifest.Embeddings, 1)
require.Equal(t, 1, manifest.Embeddings[0].Rows)
require.NotEmpty(t, manifest.Embeddings[0].Files)
require.FileExists(t, filepath.Join(repo, filepath.FromSlash(manifest.Embeddings[0].Files[0])))
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
require.NoError(t, err)
defer func() { _ = dst.Close() }()
_, err = Import(ctx, dst, opts)
require.NoError(t, err)
var gotBlob []byte
var gotDimensions int
require.NoError(t, dst.DB().QueryRowContext(ctx, `
select dimensions, embedding_blob
from message_embeddings
where message_id = 'm1'
and provider = 'openai'
and model = 'text-embedding-3-small'
and input_version = ?
`, store.EmbeddingInputVersion).Scan(&gotDimensions, &gotBlob))
require.Equal(t, 2, gotDimensions)
gotVector, err := store.DecodeEmbeddingVector(gotBlob)
require.NoError(t, err)
require.Equal(t, vector, gotVector)
}
func TestImportIfChangedSkipsSameManifest(t *testing.T) {
ctx := context.Background()
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))