feat: back up embeddings in git snapshots
This commit is contained in:
parent
4437514537
commit
413d481c3e
@ -13,6 +13,7 @@ All notable changes to `discrawl` will be documented in this file.
|
||||
- semantic message search now ranks across the full compatible local vector set instead of only the newest candidate window. (#36) Thanks @GaosCode.
|
||||
- hybrid message search now fuses FTS with local semantic vectors while avoiding embedding-provider calls when no local vectors exist. (#37) Thanks @GaosCode.
|
||||
- docs now cover semantic and hybrid search setup, embedding privacy, Git snapshot behavior, and local vector rebuilds. (#39) Thanks @GaosCode.
|
||||
- Git snapshot publishing can now opt in to backing up generated embedding vectors with `--with-embeddings` while still keeping embedding queue state local.
|
||||
|
||||
## 0.3.0 - 2026-04-21
|
||||
|
||||
|
||||
@ -377,6 +377,15 @@ Hybrid mode is supported too: keep normal Discord credentials configured and set
|
||||
|
||||
Git snapshots publish archive tables only. Embedding queue state and generated vectors stay local to each machine. Git-only readers can use FTS immediately. To use semantic or hybrid search with semantic recall, configure a local embedding provider and run `discrawl embed --rebuild`. Hybrid search falls back to FTS when no local vectors exist.
|
||||
|
||||
If you want to back up generated vectors too, publish them explicitly:
|
||||
|
||||
```bash
|
||||
discrawl publish --with-embeddings --push
|
||||
discrawl update --with-embeddings
|
||||
```
|
||||
|
||||
`--with-embeddings` exports and imports stored vectors for the configured `[search.embeddings]` provider/model/input version. It never exports `embedding_jobs`.
|
||||
|
||||
The Docker smoke test installs `discrawl` in a clean Go container, subscribes to a Git snapshot repo, then checks `search`, `messages`, `sql`, and `report`:
|
||||
|
||||
```bash
|
||||
|
||||
@ -24,6 +24,7 @@ func (r *runtime) runPublish(args []string) error {
|
||||
readmePath := fs.String("readme", "", "")
|
||||
noCommit := fs.Bool("no-commit", false, "")
|
||||
push := fs.Bool("push", false, "")
|
||||
withEmbeddings := fs.Bool("with-embeddings", false, "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return usageErr(err)
|
||||
}
|
||||
@ -34,6 +35,9 @@ func (r *runtime) runPublish(args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *withEmbeddings {
|
||||
applyEmbeddingShareOptions(&opts, r.cfg)
|
||||
}
|
||||
manifest, err := share.Export(r.ctx, r.store, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -75,6 +79,7 @@ func (r *runtime) runPublish(args []string) error {
|
||||
"remote": opts.Remote,
|
||||
"generated_at": manifest.GeneratedAt,
|
||||
"tables": manifest.Tables,
|
||||
"embeddings": manifest.Embeddings,
|
||||
"readme": *readmePath,
|
||||
"committed": committed,
|
||||
"pushed": *push,
|
||||
@ -89,6 +94,7 @@ func (r *runtime) runSubscribe(args []string) error {
|
||||
staleAfter := fs.String("stale-after", "15m", "")
|
||||
noAutoUpdate := fs.Bool("no-auto-update", false, "")
|
||||
noImport := fs.Bool("no-import", false, "")
|
||||
withEmbeddings := fs.Bool("with-embeddings", false, "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return usageErr(err)
|
||||
}
|
||||
@ -134,6 +140,9 @@ func (r *runtime) runSubscribe(args []string) error {
|
||||
return configErr(err)
|
||||
}
|
||||
opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch}
|
||||
if *withEmbeddings {
|
||||
applyEmbeddingShareOptions(&opts, cfg)
|
||||
}
|
||||
if err := share.Pull(r.ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -147,6 +156,7 @@ func (r *runtime) runSubscribe(args []string) error {
|
||||
"remote": opts.Remote,
|
||||
"generated_at": manifest.GeneratedAt,
|
||||
"tables": manifest.Tables,
|
||||
"embeddings": manifest.Embeddings,
|
||||
"imported": imported,
|
||||
})
|
||||
}
|
||||
@ -157,6 +167,7 @@ func (r *runtime) runUpdate(args []string) error {
|
||||
repoPath := fs.String("repo", r.cfg.Share.RepoPath, "")
|
||||
remote := fs.String("remote", r.cfg.Share.Remote, "")
|
||||
branch := fs.String("branch", r.cfg.Share.Branch, "")
|
||||
withEmbeddings := fs.Bool("with-embeddings", false, "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return usageErr(err)
|
||||
}
|
||||
@ -167,6 +178,9 @@ func (r *runtime) runUpdate(args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *withEmbeddings {
|
||||
applyEmbeddingShareOptions(&opts, r.cfg)
|
||||
}
|
||||
if err := share.Pull(r.ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -179,6 +193,7 @@ func (r *runtime) runUpdate(args []string) error {
|
||||
"remote": opts.Remote,
|
||||
"generated_at": manifest.GeneratedAt,
|
||||
"tables": manifest.Tables,
|
||||
"embeddings": manifest.Embeddings,
|
||||
"imported": imported,
|
||||
})
|
||||
}
|
||||
@ -197,6 +212,13 @@ func shareOptionsFromFlags(repoPath, remote, branch string) (share.Options, erro
|
||||
return share.Options{RepoPath: expandedRepo, Remote: remote, Branch: branch}, nil
|
||||
}
|
||||
|
||||
func applyEmbeddingShareOptions(opts *share.Options, cfg config.Config) {
|
||||
opts.IncludeEmbeddings = true
|
||||
opts.EmbeddingProvider = cfg.Search.Embeddings.Provider
|
||||
opts.EmbeddingModel = cfg.Search.Embeddings.Model
|
||||
opts.EmbeddingInputVersion = store.EmbeddingInputVersion
|
||||
}
|
||||
|
||||
func loadConfigOrDefault(path string) (config.Config, error) {
|
||||
cfg, err := config.Load(path)
|
||||
if err == nil {
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -42,16 +43,21 @@ var SnapshotTables = []string{
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
RepoPath string
|
||||
Remote string
|
||||
Branch string
|
||||
RepoPath string
|
||||
Remote string
|
||||
Branch string
|
||||
IncludeEmbeddings bool
|
||||
EmbeddingProvider string
|
||||
EmbeddingModel string
|
||||
EmbeddingInputVersion string
|
||||
}
|
||||
|
||||
type Manifest struct {
|
||||
Version int `json:"version"`
|
||||
GeneratedAt time.Time `json:"generated_at"`
|
||||
Tables []TableManifest `json:"tables"`
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
Version int `json:"version"`
|
||||
GeneratedAt time.Time `json:"generated_at"`
|
||||
Tables []TableManifest `json:"tables"`
|
||||
Embeddings []EmbeddingManifest `json:"embeddings,omitempty"`
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
}
|
||||
|
||||
type TableManifest struct {
|
||||
@ -62,6 +68,15 @@ type TableManifest struct {
|
||||
Rows int `json:"rows"`
|
||||
}
|
||||
|
||||
type EmbeddingManifest struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
InputVersion string `json:"input_version"`
|
||||
Files []string `json:"files"`
|
||||
Columns []string `json:"columns"`
|
||||
Rows int `json:"rows"`
|
||||
}
|
||||
|
||||
func EnsureRepo(ctx context.Context, opts Options) error {
|
||||
if strings.TrimSpace(opts.RepoPath) == "" {
|
||||
return fmt.Errorf("share repo path is empty")
|
||||
@ -163,11 +178,10 @@ func Export(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
if err := EnsureRepo(ctx, opts); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
dataDir := filepath.Join(opts.RepoPath, "tables")
|
||||
if err := os.RemoveAll(dataDir); err != nil {
|
||||
if err := os.RemoveAll(filepath.Join(opts.RepoPath, "tables")); err != nil {
|
||||
return Manifest{}, fmt.Errorf("reset tables dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(dataDir, 0o755); err != nil {
|
||||
if err := os.MkdirAll(filepath.Join(opts.RepoPath, "tables"), 0o755); err != nil {
|
||||
return Manifest{}, fmt.Errorf("mkdir tables dir: %w", err)
|
||||
}
|
||||
manifest := Manifest{
|
||||
@ -175,13 +189,25 @@ func Export(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
GeneratedAt: time.Now().UTC(),
|
||||
Files: map[string]string{"manifest": ManifestName},
|
||||
}
|
||||
if !opts.IncludeEmbeddings {
|
||||
if previous, err := ReadManifest(opts.RepoPath); err == nil {
|
||||
manifest.Embeddings = previous.Embeddings
|
||||
}
|
||||
}
|
||||
for _, table := range SnapshotTables {
|
||||
entry, err := exportTable(ctx, s.DB(), dataDir, table)
|
||||
entry, err := exportTable(ctx, s.DB(), opts.RepoPath, table)
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
manifest.Tables = append(manifest.Tables, entry)
|
||||
}
|
||||
if opts.IncludeEmbeddings {
|
||||
entry, err := exportEmbeddings(ctx, s.DB(), opts)
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
manifest.Embeddings = []EmbeddingManifest{entry}
|
||||
}
|
||||
body, err := json.MarshalIndent(manifest, "", " ")
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
@ -234,6 +260,11 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
return Manifest{}, err
|
||||
}
|
||||
}
|
||||
if opts.IncludeEmbeddings {
|
||||
if err := importEmbeddings(ctx, tx, opts, manifest.Embeddings); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
@ -283,6 +314,11 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
if ManifestAlreadyImported(ctx, s, manifest) {
|
||||
if opts.IncludeEmbeddings {
|
||||
if err := ImportEmbeddings(ctx, s, opts, manifest); err != nil {
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
}
|
||||
if err := MarkImported(ctx, s, manifest); err != nil {
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
@ -295,6 +331,27 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
|
||||
return imported, true, nil
|
||||
}
|
||||
|
||||
func ImportEmbeddings(ctx context.Context, s *store.Store, opts Options, manifest Manifest) error {
|
||||
tx, err := s.DB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
committed := false
|
||||
defer func() {
|
||||
if !committed {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
if err := importEmbeddings(ctx, tx, opts, manifest.Embeddings); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
committed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func ManifestAlreadyImported(ctx context.Context, s *store.Store, manifest Manifest) bool {
|
||||
if manifest.GeneratedAt.IsZero() {
|
||||
return false
|
||||
@ -353,7 +410,7 @@ func NeedsImport(ctx context.Context, s *store.Store, staleAfter time.Duration)
|
||||
return time.Since(t) >= staleAfter
|
||||
}
|
||||
|
||||
func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableManifest, error) {
|
||||
func exportTable(ctx context.Context, db *sql.DB, repoPath, table string) (TableManifest, error) {
|
||||
query := "select * from " + table
|
||||
rows, err := db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
@ -364,11 +421,11 @@ func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableM
|
||||
if err != nil {
|
||||
return TableManifest{}, fmt.Errorf("columns %s: %w", table, err)
|
||||
}
|
||||
tableDir := filepath.Join(dataDir, table)
|
||||
tableDir := filepath.Join(repoPath, "tables", table)
|
||||
if err := os.MkdirAll(tableDir, 0o755); err != nil {
|
||||
return TableManifest{}, fmt.Errorf("mkdir %s: %w", table, err)
|
||||
}
|
||||
writer := tableShardWriter{dataDir: dataDir, table: table}
|
||||
writer := tableShardWriter{rootDir: repoPath, relDir: filepath.ToSlash(filepath.Join("tables", table)), label: table}
|
||||
if err := writer.open(); err != nil {
|
||||
return TableManifest{}, err
|
||||
}
|
||||
@ -415,6 +472,95 @@ func exportTable(ctx context.Context, db *sql.DB, dataDir, table string) (TableM
|
||||
return TableManifest{Name: table, Files: writer.files, Columns: columns, Rows: count}, nil
|
||||
}
|
||||
|
||||
func exportEmbeddings(ctx context.Context, db *sql.DB, opts Options) (EmbeddingManifest, error) {
|
||||
provider := strings.ToLower(strings.TrimSpace(opts.EmbeddingProvider))
|
||||
model := strings.TrimSpace(opts.EmbeddingModel)
|
||||
inputVersion := strings.TrimSpace(opts.EmbeddingInputVersion)
|
||||
if inputVersion == "" {
|
||||
inputVersion = store.EmbeddingInputVersion
|
||||
}
|
||||
if provider == "" || model == "" {
|
||||
return EmbeddingManifest{}, fmt.Errorf("embedding provider and model are required")
|
||||
}
|
||||
relDir := filepath.ToSlash(filepath.Join("embeddings", safePathSegment(provider), safePathSegment(model), safePathSegment(inputVersion)))
|
||||
if err := os.RemoveAll(filepath.Join(opts.RepoPath, "embeddings")); err != nil {
|
||||
return EmbeddingManifest{}, fmt.Errorf("reset embeddings dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(opts.RepoPath, filepath.FromSlash(relDir)), 0o755); err != nil {
|
||||
return EmbeddingManifest{}, fmt.Errorf("mkdir %s: %w", relDir, err)
|
||||
}
|
||||
rows, err := db.QueryContext(ctx, `
|
||||
select message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
|
||||
from message_embeddings
|
||||
where provider = ? and model = ? and input_version = ?
|
||||
order by message_id
|
||||
`, provider, model, inputVersion)
|
||||
if err != nil {
|
||||
return EmbeddingManifest{}, fmt.Errorf("query message_embeddings: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
writer := tableShardWriter{rootDir: opts.RepoPath, relDir: relDir, label: "message_embeddings"}
|
||||
if err := writer.open(); err != nil {
|
||||
return EmbeddingManifest{}, err
|
||||
}
|
||||
defer func() { _ = writer.close() }()
|
||||
columns := []string{"message_id", "provider", "model", "input_version", "dimensions", "embedding_blob", "embedded_at"}
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
var (
|
||||
messageID string
|
||||
rowProv string
|
||||
rowModel string
|
||||
rowInput string
|
||||
dimensions int
|
||||
blob []byte
|
||||
embeddedAt string
|
||||
)
|
||||
if err := rows.Scan(&messageID, &rowProv, &rowModel, &rowInput, &dimensions, &blob, &embeddedAt); err != nil {
|
||||
return EmbeddingManifest{}, fmt.Errorf("scan message_embeddings: %w", err)
|
||||
}
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"message_id": messageID,
|
||||
"provider": rowProv,
|
||||
"model": rowModel,
|
||||
"input_version": rowInput,
|
||||
"dimensions": dimensions,
|
||||
"embedding_blob": base64.StdEncoding.EncodeToString(blob),
|
||||
"embedded_at": embeddedAt,
|
||||
})
|
||||
if err != nil {
|
||||
return EmbeddingManifest{}, fmt.Errorf("marshal message_embeddings row: %w", err)
|
||||
}
|
||||
if err := writer.rotateIfNeeded(); err != nil {
|
||||
return EmbeddingManifest{}, err
|
||||
}
|
||||
if _, err := writer.Write(body); err != nil {
|
||||
return EmbeddingManifest{}, fmt.Errorf("write message_embeddings row: %w", err)
|
||||
}
|
||||
if _, err := writer.Write([]byte{'\n'}); err != nil {
|
||||
return EmbeddingManifest{}, fmt.Errorf("write message_embeddings newline: %w", err)
|
||||
}
|
||||
count++
|
||||
if err := writer.finishRow(); err != nil {
|
||||
return EmbeddingManifest{}, err
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return EmbeddingManifest{}, fmt.Errorf("iterate message_embeddings: %w", err)
|
||||
}
|
||||
if err := writer.close(); err != nil {
|
||||
return EmbeddingManifest{}, err
|
||||
}
|
||||
return EmbeddingManifest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
InputVersion: inputVersion,
|
||||
Files: writer.files,
|
||||
Columns: columns,
|
||||
Rows: count,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableManifest) error {
|
||||
files := table.Files
|
||||
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
|
||||
@ -470,9 +616,104 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
|
||||
return nil
|
||||
}
|
||||
|
||||
func importEmbeddings(ctx context.Context, tx *sql.Tx, opts Options, manifests []EmbeddingManifest) error {
|
||||
if len(manifests) == 0 {
|
||||
return nil
|
||||
}
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
insert into message_embeddings(
|
||||
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
|
||||
) values(?, ?, ?, ?, ?, ?, ?)
|
||||
on conflict(message_id, provider, model, input_version) do update set
|
||||
dimensions = excluded.dimensions,
|
||||
embedding_blob = excluded.embedding_blob,
|
||||
embedded_at = excluded.embedded_at
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("prepare import message_embeddings: %w", err)
|
||||
}
|
||||
defer func() { _ = stmt.Close() }()
|
||||
for _, manifest := range manifests {
|
||||
if !embeddingManifestMatches(opts, manifest) {
|
||||
continue
|
||||
}
|
||||
files := manifest.Files
|
||||
if len(files) == 0 {
|
||||
return fmt.Errorf("embedding manifest %s/%s/%s has no files", manifest.Provider, manifest.Model, manifest.InputVersion)
|
||||
}
|
||||
for _, rel := range files {
|
||||
if err := importEmbeddingFile(ctx, stmt, opts.RepoPath, rel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func importEmbeddingFile(ctx context.Context, stmt *sql.Stmt, repoPath, rel string) error {
|
||||
path := filepath.Join(repoPath, filepath.FromSlash(rel))
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open %s: %w", rel, err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
gz, err := gzip.NewReader(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read gzip %s: %w", rel, err)
|
||||
}
|
||||
defer func() { _ = gz.Close() }()
|
||||
dec := json.NewDecoder(gz)
|
||||
dec.UseNumber()
|
||||
for {
|
||||
var row struct {
|
||||
MessageID string `json:"message_id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
InputVersion string `json:"input_version"`
|
||||
Dimensions json.Number `json:"dimensions"`
|
||||
EmbeddingBlob string `json:"embedding_blob"`
|
||||
EmbeddedAt string `json:"embedded_at"`
|
||||
}
|
||||
err := dec.Decode(&row)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode %s: %w", rel, err)
|
||||
}
|
||||
dimensions, err := strconv.Atoi(row.Dimensions.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode dimensions in %s: %w", rel, err)
|
||||
}
|
||||
blob, err := base64.StdEncoding.DecodeString(row.EmbeddingBlob)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode embedding blob in %s: %w", rel, err)
|
||||
}
|
||||
if _, err := stmt.ExecContext(ctx, row.MessageID, row.Provider, row.Model, row.InputVersion, dimensions, blob, row.EmbeddedAt); err != nil {
|
||||
return fmt.Errorf("insert message_embeddings: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func embeddingManifestMatches(opts Options, manifest EmbeddingManifest) bool {
|
||||
if strings.TrimSpace(opts.EmbeddingProvider) != "" && manifest.Provider != strings.ToLower(strings.TrimSpace(opts.EmbeddingProvider)) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(opts.EmbeddingModel) != "" && manifest.Model != strings.TrimSpace(opts.EmbeddingModel) {
|
||||
return false
|
||||
}
|
||||
inputVersion := strings.TrimSpace(opts.EmbeddingInputVersion)
|
||||
if inputVersion == "" {
|
||||
inputVersion = store.EmbeddingInputVersion
|
||||
}
|
||||
return manifest.InputVersion == inputVersion
|
||||
}
|
||||
|
||||
type tableShardWriter struct {
|
||||
dataDir string
|
||||
table string
|
||||
rootDir string
|
||||
relDir string
|
||||
label string
|
||||
nextShard int
|
||||
rowsInShard int
|
||||
files []string
|
||||
@ -482,8 +723,8 @@ type tableShardWriter struct {
|
||||
}
|
||||
|
||||
func (w *tableShardWriter) open() error {
|
||||
rel := filepath.ToSlash(filepath.Join("tables", w.table, fmt.Sprintf("%06d.jsonl.gz", w.nextShard)))
|
||||
path := filepath.Join(w.dataDir, w.table, fmt.Sprintf("%06d.jsonl.gz", w.nextShard))
|
||||
rel := filepath.ToSlash(filepath.Join(w.relDir, fmt.Sprintf("%06d.jsonl.gz", w.nextShard)))
|
||||
path := filepath.Join(w.rootDir, filepath.FromSlash(rel))
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create %s: %w", rel, err)
|
||||
@ -517,7 +758,7 @@ func (w *tableShardWriter) finishRow() error {
|
||||
return nil
|
||||
}
|
||||
if err := w.gz.Flush(); err != nil {
|
||||
return fmt.Errorf("flush %s shard: %w", w.table, err)
|
||||
return fmt.Errorf("flush %s shard: %w", w.label, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -537,7 +778,7 @@ func (w *tableShardWriter) close() error {
|
||||
w.file = nil
|
||||
}
|
||||
if closeErr != nil {
|
||||
return fmt.Errorf("close %s shard: %w", w.table, closeErr)
|
||||
return fmt.Errorf("close %s shard: %w", w.label, closeErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -591,6 +832,29 @@ func quoteIdent(s string) string {
|
||||
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
func safePathSegment(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "_"
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
b.WriteRune(r)
|
||||
case r >= 'A' && r <= 'Z':
|
||||
b.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
b.WriteRune(r)
|
||||
default:
|
||||
b.WriteRune('_')
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func run(ctx context.Context, dir, name string, args ...string) error {
|
||||
out, err := output(ctx, dir, name, args...)
|
||||
if err != nil {
|
||||
|
||||
@ -75,6 +75,7 @@ func TestSnapshotExcludesLocalEmbeddingState(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, tableNames(manifest), "embedding_jobs")
|
||||
require.NotContains(t, tableNames(manifest), "message_embeddings")
|
||||
require.Empty(t, manifest.Embeddings)
|
||||
|
||||
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
|
||||
require.NoError(t, err)
|
||||
@ -94,6 +95,60 @@ func TestSnapshotExcludesLocalEmbeddingState(t *testing.T) {
|
||||
require.Equal(t, "pending", state)
|
||||
}
|
||||
|
||||
func TestExportImportEmbeddingsOptIn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
|
||||
defer func() { _ = src.Close() }()
|
||||
|
||||
vector := []float32{1, 0.5}
|
||||
blob, err := store.EncodeEmbeddingVector(vector)
|
||||
require.NoError(t, err)
|
||||
embeddedAt := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
_, err = src.DB().ExecContext(ctx, `
|
||||
insert into message_embeddings(
|
||||
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
|
||||
) values ('m1', 'openai', 'text-embedding-3-small', ?, 2, ?, ?)
|
||||
`, store.EmbeddingInputVersion, blob, embeddedAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := filepath.Join(t.TempDir(), "share")
|
||||
opts := Options{
|
||||
RepoPath: repo,
|
||||
Branch: "main",
|
||||
IncludeEmbeddings: true,
|
||||
EmbeddingProvider: "openai",
|
||||
EmbeddingModel: "text-embedding-3-small",
|
||||
EmbeddingInputVersion: store.EmbeddingInputVersion,
|
||||
}
|
||||
manifest, err := Export(ctx, src, opts)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, manifest.Embeddings, 1)
|
||||
require.Equal(t, 1, manifest.Embeddings[0].Rows)
|
||||
require.NotEmpty(t, manifest.Embeddings[0].Files)
|
||||
require.FileExists(t, filepath.Join(repo, filepath.FromSlash(manifest.Embeddings[0].Files[0])))
|
||||
|
||||
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = dst.Close() }()
|
||||
_, err = Import(ctx, dst, opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
var gotBlob []byte
|
||||
var gotDimensions int
|
||||
require.NoError(t, dst.DB().QueryRowContext(ctx, `
|
||||
select dimensions, embedding_blob
|
||||
from message_embeddings
|
||||
where message_id = 'm1'
|
||||
and provider = 'openai'
|
||||
and model = 'text-embedding-3-small'
|
||||
and input_version = ?
|
||||
`, store.EmbeddingInputVersion).Scan(&gotDimensions, &gotBlob))
|
||||
require.Equal(t, 2, gotDimensions)
|
||||
gotVector, err := store.DecodeEmbeddingVector(gotBlob)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, vector, gotVector)
|
||||
}
|
||||
|
||||
func TestImportIfChangedSkipsSameManifest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user