Compare commits


8 Commits
v0.4.1 ... main

Author SHA1 Message Date
Peter Steinberger
55196d74e0
feat: add shared encrypted backup helpers
2026-05-08 16:41:44 +01:00
Peter Steinberger
1cc2c66283
feat: add shared embedding and vector helpers 2026-05-08 09:56:40 +01:00
Peter Steinberger
7fbca35339
docs: prepare v0.4.2 release notes 2026-05-08 07:58:41 +01:00
Peter Steinberger
4d976d782b
feat(snapshot): add incremental shard import planning 2026-05-08 07:10:12 +01:00
Vincent Koc
6a2ae79aa6
docs: note dependency updates
2026-05-07 02:52:15 -07:00
dependabot[bot]
bbc8e09c07
build(deps): bump github.com/mattn/go-isatty from 0.0.20 to 0.0.22 (#4)
Bumps [github.com/mattn/go-isatty](https://github.com/mattn/go-isatty) from 0.0.20 to 0.0.22.
- [Commits](https://github.com/mattn/go-isatty/compare/v0.0.20...v0.0.22)

---
updated-dependencies:
- dependency-name: github.com/mattn/go-isatty
  dependency-version: 0.0.22
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 02:45:10 -07:00
dependabot[bot]
7fc82a3e08
build(deps): bump github.com/charmbracelet/x/ansi from 0.11.6 to 0.11.7 (#3)
Bumps [github.com/charmbracelet/x/ansi](https://github.com/charmbracelet/x) from 0.11.6 to 0.11.7.
- [Commits](https://github.com/charmbracelet/x/compare/ansi/v0.11.6...ansi/v0.11.7)

---
updated-dependencies:
- dependency-name: github.com/charmbracelet/x/ansi
  dependency-version: 0.11.7
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 02:40:44 -07:00
dependabot[bot]
9529b8547b
build(deps): bump github.com/pelletier/go-toml/v2 from 2.3.0 to 2.3.1 (#2)
Bumps [github.com/pelletier/go-toml/v2](https://github.com/pelletier/go-toml) from 2.3.0 to 2.3.1.
- [Release notes](https://github.com/pelletier/go-toml/releases)
- [Commits](https://github.com/pelletier/go-toml/compare/v2.3.0...v2.3.1)

---
updated-dependencies:
- dependency-name: github.com/pelletier/go-toml/v2
  dependency-version: 2.3.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 02:40:38 -07:00
18 changed files with 2312 additions and 57 deletions

View File

@ -41,7 +41,7 @@ GOWORK=off go test -count=1 ./...
For release readiness, also verify the public module tag:
```bash
-GOPROXY=https://proxy.golang.org GONOSUMDB= go list -m github.com/vincentkoc/crawlkit@v0.4.0
+GOPROXY=https://proxy.golang.org GONOSUMDB= go list -m github.com/openclaw/crawlkit@v0.5.0
```
## Downstream Compatibility

View File

@ -2,6 +2,23 @@
## Unreleased
## v0.5.1 - 2026-05-08
- Add reusable `backup` helpers for age identities, encrypted JSONL/Gzip shards,
manifests, recipient tracking, shard hash verification, and stale shard
cleanup.
- Add reusable `embed` providers for OpenAI, OpenAI-compatible endpoints,
Ollama, and llama.cpp, including probe diagnostics and rate-limit errors.
- Add reusable `vector` helpers for float32 blobs, dimension validation,
cosine similarity, top-k sorting, and reciprocal-rank fusion.
## v0.4.2 - 2026-05-08
- Add snapshot file fingerprints and an incremental import planner/executor so downstream apps can import changed JSONL/Gzip shards without deleting every table.
- Move the module path to `github.com/openclaw/crawlkit`.
- Bump routine Go module dependencies.
## v0.4.1 - 2026-05-06
- Add GitHub Sponsors funding metadata.

View File

@ -11,7 +11,7 @@ safe desktop-cache snapshot utilities.
## Install
```bash
-go get github.com/vincentkoc/crawlkit@latest
+go get github.com/openclaw/crawlkit@latest
```
Go packages are published by tagging this repository. There is no separate
@ -22,9 +22,12 @@ See `docs/boundary.md` for the crawlkit-versus-app ownership boundary.
- `config`: standard TOML config paths, runtime dirs, and token diagnostics.
- `store`: SQLite open/read-only/transaction/query helpers.
- `snapshot`: `manifest.json` plus JSONL/Gzip table snapshot export and import.
- `snapshot`: `manifest.json` plus JSONL/Gzip table snapshot export, file fingerprints, full import, and planned incremental shard import.
- `backup`: age-encrypted JSONL/Gzip shards, backup manifests, recipient/identity helpers, and shard restore verification.
- `mirror`: clone/init/pull/commit/push helpers for private snapshot repos.
- `state`: generic crawler cursor and freshness records.
- `embed`: reusable OpenAI-compatible, Ollama, and llama.cpp embedding providers plus local probe diagnostics.
- `vector`: float32 vector encoding, dimension validation, cosine scoring, top-k helpers, and reciprocal-rank fusion.
- `output`: text/json/log output helpers.
- `control`: crawl app metadata, command manifests, status payloads, and
database inventory for launchers and automation.
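The `vector` entry above names cosine scoring and reciprocal-rank fusion. As a rough illustration of those two techniques (not the package's actual API, whose source is not shown in this comparison), a minimal standalone sketch might look like the following. The function names `cosine` and `fuseRRF` and the constant `k = 60` are assumptions chosen for the example.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// cosine returns the cosine similarity of two equal-length vectors.
// It returns 0 when the lengths differ or either vector has zero norm.
// (Hypothetical helper; not taken from crawlkit.)
func cosine(a, b []float32) float64 {
	if len(a) != len(b) || len(a) == 0 {
		return 0
	}
	var dot, na, nb float64
	for i := range a {
		dot += float64(a[i]) * float64(b[i])
		na += float64(a[i]) * float64(a[i])
		nb += float64(b[i]) * float64(b[i])
	}
	if na == 0 || nb == 0 {
		return 0
	}
	return dot / (math.Sqrt(na) * math.Sqrt(nb))
}

// fuseRRF merges several ranked ID lists with reciprocal-rank fusion:
// score(id) = sum over lists of 1/(k + rank), with rank starting at 1.
// Ties are broken by ID so the output order is stable.
func fuseRRF(k float64, rankings ...[]string) []string {
	scores := map[string]float64{}
	for _, ranking := range rankings {
		for i, id := range ranking {
			scores[id] += 1 / (k + float64(i+1))
		}
	}
	ids := make([]string, 0, len(scores))
	for id := range scores {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool {
		if scores[ids[i]] != scores[ids[j]] {
			return scores[ids[i]] > scores[ids[j]]
		}
		return ids[i] < ids[j]
	})
	return ids
}

func main() {
	fmt.Printf("%.3f\n", cosine([]float32{1, 0}, []float32{1, 1}))
	// Fuse a keyword ranking with a vector ranking of the same documents.
	fmt.Println(fuseRRF(60, []string{"a", "b", "c"}, []string{"b", "c", "a"}))
}
```

Fusion by rank rather than by raw score is convenient when the input rankings come from scorers with incomparable scales, such as a full-text search and a cosine ranking.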

backup/backup.go (new file, 339 lines added)

@ -0,0 +1,339 @@
package backup
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"path"
"path/filepath"
"reflect"
"sort"
"strings"
"time"
)
const FormatVersion = 1
type Config struct {
Repo string
Identity string
Recipients []string
}
type Manifest struct {
Format int `json:"format"`
Encrypted bool `json:"encrypted"`
Exported time.Time `json:"exported"`
Recipients []string `json:"recipients,omitempty"`
Counts map[string]int `json:"counts"`
Shards []ShardEntry `json:"shards"`
}
type Shard struct {
Table string
Path string
Rows any
}
type ShardEntry struct {
Table string `json:"table"`
Path string `json:"path"`
Rows int `json:"rows"`
SHA256 string `json:"sha256"`
Bytes int64 `json:"bytes"`
}
type DecodedShard struct {
Entry ShardEntry
Plaintext []byte
}
func WriteSnapshot(ctx context.Context, cfg Config, shards []Shard, old Manifest) (Manifest, error) {
_ = ctx
recipients := normalizedStrings(cfg.Recipients)
reuseEncrypted := sameStrings(old.Recipients, recipients)
manifest := Manifest{
Format: FormatVersion,
Encrypted: true,
Exported: time.Now().UTC(),
Recipients: recipients,
Counts: map[string]int{},
}
for _, shard := range shards {
plaintext, rows, err := EncodeJSONL(shard.Rows)
if err != nil {
return Manifest{}, fmt.Errorf("encode %s: %w", shard.Table, err)
}
entry, err := WriteShard(cfg, old, shard.Table, shard.Path, plaintext, rows, reuseEncrypted)
if err != nil {
return Manifest{}, err
}
manifest.Counts[shard.Table] += rows
manifest.Shards = append(manifest.Shards, entry)
}
sort.Slice(manifest.Shards, func(i, j int) bool { return manifest.Shards[i].Path < manifest.Shards[j].Path })
if EquivalentManifest(old, manifest) {
return old, nil
}
if err := RemoveStaleShards(cfg.Repo, manifest.Shards); err != nil {
return Manifest{}, err
}
if err := WriteManifest(cfg.Repo, manifest); err != nil {
return Manifest{}, err
}
return manifest, nil
}
func ReadSnapshot(cfg Config, manifest Manifest) ([]DecodedShard, error) {
if manifest.Format != FormatVersion {
return nil, fmt.Errorf("unsupported backup format %d", manifest.Format)
}
var out []DecodedShard
for _, shard := range manifest.Shards {
plaintext, err := DecryptShardFile(cfg, shard)
if err != nil {
return nil, err
}
if got := SHA256Hex(plaintext); got != shard.SHA256 {
return nil, fmt.Errorf("backup shard hash mismatch for %s", shard.Path)
}
out = append(out, DecodedShard{Entry: shard, Plaintext: plaintext})
}
return out, nil
}
func WriteShard(cfg Config, old Manifest, table, rel string, plaintext []byte, rows int, reuseEncrypted bool) (ShardEntry, error) {
hash := SHA256Hex(plaintext)
target, err := ResolveShardPath(cfg.Repo, rel)
if err != nil {
return ShardEntry{}, err
}
if oldEntry, ok := old.Entry(rel); reuseEncrypted && ok && oldEntry.SHA256 == hash {
if info, err := os.Stat(target); err == nil {
oldEntry.Bytes = info.Size()
return oldEntry, nil
}
}
encrypted, _, err := EncryptShard(plaintext, cfg.Recipients)
if err != nil {
return ShardEntry{}, err
}
if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
return ShardEntry{}, err
}
if err := os.WriteFile(target, encrypted, 0o600); err != nil {
return ShardEntry{}, err
}
return ShardEntry{Table: table, Path: rel, Rows: rows, SHA256: hash, Bytes: int64(len(encrypted))}, nil
}
func DecryptShardFile(cfg Config, shard ShardEntry) ([]byte, error) {
target, err := ResolveShardPath(cfg.Repo, shard.Path)
if err != nil {
return nil, err
}
ciphertext, err := os.ReadFile(target) // #nosec G304 -- ResolveShardPath confines manifest-controlled paths below data/.
if err != nil {
return nil, err
}
return DecryptShard(ciphertext, cfg.Identity)
}
func ResolveShardPath(repo, rel string) (string, error) {
clean := path.Clean(strings.TrimSpace(rel))
if clean == "." || clean == ".." || strings.HasPrefix(clean, "../") || path.IsAbs(clean) {
return "", fmt.Errorf("backup shard path escapes backup root: %s", rel)
}
if !strings.HasPrefix(clean, "data/") || !strings.HasSuffix(clean, ".age") {
return "", fmt.Errorf("invalid backup shard path: %s", rel)
}
full := filepath.Join(repo, filepath.FromSlash(clean))
root := filepath.Clean(filepath.Join(repo, "data"))
parent := filepath.Clean(filepath.Dir(full))
if parent != root && !strings.HasPrefix(parent, root+string(filepath.Separator)) {
return "", fmt.Errorf("backup shard path escapes backup root: %s", rel)
}
return full, nil
}
func EncodeJSONL(rows any) ([]byte, int, error) {
value := reflect.ValueOf(rows)
if value.Kind() != reflect.Slice {
return nil, 0, fmt.Errorf("unsupported JSONL rows %T", rows)
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
for i := 0; i < value.Len(); i++ {
if err := enc.Encode(value.Index(i).Interface()); err != nil {
return nil, 0, err
}
}
return buf.Bytes(), value.Len(), nil
}
func DecodeJSONL[T any](plaintext []byte, out *[]T) error {
scanner := bufio.NewScanner(bytes.NewReader(plaintext))
scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
for scanner.Scan() {
var value T
if err := json.Unmarshal(scanner.Bytes(), &value); err != nil {
return err
}
*out = append(*out, value)
}
return scanner.Err()
}
func ReadManifest(repo string) (Manifest, error) {
data, err := os.ReadFile(filepath.Join(repo, "manifest.json")) // #nosec G304 -- repo is configured by caller.
if err != nil {
return Manifest{}, err
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return Manifest{}, err
}
return manifest, nil
}
func WriteManifest(repo string, manifest Manifest) error {
data, err := json.MarshalIndent(manifest, "", " ")
if err != nil {
return err
}
data = append(data, '\n')
return os.WriteFile(filepath.Join(repo, "manifest.json"), data, 0o600)
}
func (m Manifest) Entry(path string) (ShardEntry, bool) {
for _, shard := range m.Shards {
if shard.Path == path {
return shard, true
}
}
return ShardEntry{}, false
}
func EquivalentManifest(a, b Manifest) bool {
if a.Format != b.Format || a.Encrypted != b.Encrypted || !sameStrings(a.Recipients, b.Recipients) || !sameCounts(a.Counts, b.Counts) || len(a.Shards) != len(b.Shards) {
return false
}
for i := range a.Shards {
left, right := a.Shards[i], b.Shards[i]
left.Bytes, right.Bytes = 0, 0
if left != right {
return false
}
}
return true
}
func RemoveStaleShards(repo string, shards []ShardEntry) error {
keep := map[string]struct{}{}
for _, shard := range shards {
keep[filepath.Clean(filepath.Join(repo, filepath.FromSlash(shard.Path)))] = struct{}{}
}
root := filepath.Join(repo, "data")
if _, err := os.Stat(root); os.IsNotExist(err) {
return nil
}
var stale []string
if err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
if err != nil || d == nil || d.IsDir() {
return err
}
if !strings.HasSuffix(path, ".age") {
return nil
}
clean := filepath.Clean(path)
if _, ok := keep[clean]; ok {
return nil
}
stale = append(stale, clean)
return nil
}); err != nil {
return err
}
for _, path := range stale {
rel, err := filepath.Rel(root, path)
if err != nil || rel == "." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
return fmt.Errorf("stale shard path escapes backup root: %s", path)
}
if err := os.Remove(path); err != nil {
return err
}
}
return nil
}
func EncryptShard(plaintext []byte, recipients []string) ([]byte, string, error) {
return encryptShard(plaintext, recipients)
}
func DecryptShard(ciphertext []byte, identityPath string) ([]byte, error) {
return decryptShard(ciphertext, identityPath)
}
func SHA256Hex(data []byte) string {
return sha256Hex(data)
}
func normalizedStrings(values []string) []string {
seen := map[string]struct{}{}
out := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
out = append(out, value)
}
sort.Strings(out)
return out
}
func sameStrings(a, b []string) bool {
a, b = normalizedStrings(a), normalizedStrings(b)
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func sameCounts(a, b map[string]int) bool {
if len(a) != len(b) {
return false
}
for key, left := range a {
if b[key] != left {
return false
}
}
return true
}
func expandHome(p string) string {
if p == "~" {
if home, err := os.UserHomeDir(); err == nil {
return home
}
}
if after, ok := strings.CutPrefix(p, "~/"); ok {
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, after)
}
}
return p
}
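The JSONL encode and decode helpers in this file can be illustrated with a standalone round trip. The sketch below is an assumption-labelled variant, not the code from the diff: it uses a typed generic slice instead of reflection, skips blank lines on decode, and the names `encodeJSONL`, `decodeJSONL`, and `message` are hypothetical.

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
)

// message is a hypothetical row type used only for this example.
type message struct {
	ID   string `json:"id"`
	Body string `json:"body"`
}

// encodeJSONL writes one JSON document per line.
// json.Encoder.Encode appends a newline after each value.
func encodeJSONL[T any](rows []T) ([]byte, error) {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	for _, row := range rows {
		if err := enc.Encode(row); err != nil {
			return nil, err
		}
	}
	return buf.Bytes(), nil
}

// decodeJSONL parses one JSON document per line, skipping blank lines.
// The scanner buffer is raised so that long rows (up to 16 MiB) fit.
func decodeJSONL[T any](data []byte) ([]T, error) {
	var out []T
	scanner := bufio.NewScanner(bytes.NewReader(data))
	scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
	for scanner.Scan() {
		if len(bytes.TrimSpace(scanner.Bytes())) == 0 {
			continue
		}
		var value T
		if err := json.Unmarshal(scanner.Bytes(), &value); err != nil {
			return nil, err
		}
		out = append(out, value)
	}
	return out, scanner.Err()
}

func main() {
	data, err := encodeJSONL([]message{{ID: "1", Body: "hello"}, {ID: "2", Body: "world"}})
	if err != nil {
		panic(err)
	}
	rows, err := decodeJSONL[message](data)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(rows), rows[1].Body)
}
```

Raising the scanner buffer matters because `bufio.Scanner` otherwise rejects lines longer than 64 KiB, which large crawled records can exceed.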

backup/backup_test.go (new file, 54 lines added)

@ -0,0 +1,54 @@
package backup
import (
"context"
"os"
"path/filepath"
"testing"
)
type row struct {
ID string `json:"id"`
Body string `json:"body"`
}
func TestWriteReadEncryptedSnapshot(t *testing.T) {
dir := t.TempDir()
identity := filepath.Join(dir, "age.key")
recipient, err := EnsureIdentity(identity)
if err != nil {
t.Fatal(err)
}
cfg := Config{Repo: filepath.Join(dir, "repo"), Identity: identity, Recipients: []string{recipient}}
if err := os.MkdirAll(cfg.Repo, 0o700); err != nil {
t.Fatal(err)
}
manifest, err := WriteSnapshot(context.Background(), cfg, []Shard{
{Table: "messages", Path: "data/messages/2026/05.jsonl.gz.age", Rows: []row{{ID: "1", Body: "hello"}}},
}, Manifest{})
if err != nil {
t.Fatal(err)
}
if manifest.Counts["messages"] != 1 || len(manifest.Shards) != 1 {
t.Fatalf("unexpected manifest: %+v", manifest)
}
decoded, err := ReadSnapshot(cfg, manifest)
if err != nil {
t.Fatal(err)
}
var rows []row
if err := DecodeJSONL(decoded[0].Plaintext, &rows); err != nil {
t.Fatal(err)
}
if len(rows) != 1 || rows[0].Body != "hello" {
t.Fatalf("unexpected rows: %+v", rows)
}
}
func TestResolveShardPathRejectsEscapes(t *testing.T) {
for _, rel := range []string{"../x.age", "data/../x.age", "data/x.txt", "/data/x.age"} {
if _, err := ResolveShardPath(t.TempDir(), rel); err == nil {
t.Fatalf("expected error for %q", rel)
}
}
}

backup/crypto.go (new file, 141 lines added)

@ -0,0 +1,141 @@
package backup
import (
"bytes"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"filippo.io/age"
)
func EnsureIdentity(path string) (string, error) {
path = expandHome(path)
if data, err := os.ReadFile(path); err == nil { // #nosec G304 -- path is the configured local age identity file.
identity, err := parseIdentity(data)
if err != nil {
return "", err
}
return identity.Recipient().String(), nil
} else if !os.IsNotExist(err) {
return "", err
}
identity, err := age.GenerateX25519Identity()
if err != nil {
return "", err
}
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return "", err
}
data := []byte(identity.String() + "\n")
if err := os.WriteFile(path, data, 0o600); err != nil {
return "", err
}
return identity.Recipient().String(), nil
}
func RecipientFromIdentity(path string) (string, error) {
data, err := os.ReadFile(expandHome(path))
if err != nil {
return "", err
}
identity, err := parseIdentity(data)
if err != nil {
return "", err
}
return identity.Recipient().String(), nil
}
func encryptShard(plaintext []byte, recipientStrings []string) ([]byte, string, error) {
recipients, err := parseRecipients(recipientStrings)
if err != nil {
return nil, "", err
}
var compressed bytes.Buffer
gz := gzip.NewWriter(&compressed)
gz.ModTime = time.Unix(0, 0).UTC()
_, _ = gz.Write(plaintext)
_ = gz.Close()
var encrypted bytes.Buffer
w, err := age.Encrypt(&encrypted, recipients...)
if err != nil {
return nil, "", err
}
_, _ = w.Write(compressed.Bytes())
if err := w.Close(); err != nil {
return nil, "", err
}
return encrypted.Bytes(), sha256Hex(plaintext), nil
}
func decryptShard(ciphertext []byte, identityPath string) ([]byte, error) {
data, err := os.ReadFile(expandHome(identityPath)) // #nosec G304 -- path is the configured local age identity file.
if err != nil {
return nil, err
}
identity, err := parseIdentity(data)
if err != nil {
return nil, err
}
r, err := age.Decrypt(bytes.NewReader(ciphertext), identity)
if err != nil {
return nil, err
}
gz, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
defer func() { _ = gz.Close() }()
plaintext, err := io.ReadAll(gz)
if err != nil {
return nil, err
}
return plaintext, nil
}
func parseRecipients(values []string) ([]age.Recipient, error) {
var out []age.Recipient
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
recipient, err := age.ParseX25519Recipient(value)
if err != nil {
return nil, fmt.Errorf("parse age recipient: %w", err)
}
out = append(out, recipient)
}
if len(out) == 0 {
return nil, fmt.Errorf("at least one age recipient is required")
}
return out, nil
}
func parseIdentity(data []byte) (*age.X25519Identity, error) {
for _, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
identity, err := age.ParseX25519Identity(line)
if err != nil {
return nil, fmt.Errorf("parse age identity: %w", err)
}
return identity, nil
}
return nil, fmt.Errorf("age identity file is empty")
}
func sha256Hex(data []byte) string {
sum := sha256.Sum256(data)
return hex.EncodeToString(sum[:])
}
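The shard pipeline in this file compresses with a fixed gzip modification time and records a SHA-256 of the plaintext, which the restore path later checks. The following sketch shows only those two steps using the standard library; the age encryption layer is deliberately omitted because it is a third-party dependency. The names `compress` and `decompressVerified` are hypothetical, and unlike the diff this version returns gzip write errors instead of discarding them.

```go
package main

import (
	"bytes"
	"compress/gzip"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"time"
)

// sha256Hex returns the lowercase hex SHA-256 digest of data.
func sha256Hex(data []byte) string {
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}

// compress gzips plaintext with a fixed header timestamp so that the same
// input always produces the same bytes.
func compress(plaintext []byte) ([]byte, error) {
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	gz.ModTime = time.Unix(0, 0).UTC()
	if _, err := gz.Write(plaintext); err != nil {
		return nil, err
	}
	if err := gz.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// decompressVerified gunzips the data and rejects it when the plaintext
// hash does not match the expected manifest value.
func decompressVerified(compressed []byte, wantSHA256 string) ([]byte, error) {
	gz, err := gzip.NewReader(bytes.NewReader(compressed))
	if err != nil {
		return nil, err
	}
	defer gz.Close()
	plaintext, err := io.ReadAll(gz)
	if err != nil {
		return nil, err
	}
	if got := sha256Hex(plaintext); got != wantSHA256 {
		return nil, fmt.Errorf("hash mismatch: got %s want %s", got, wantSHA256)
	}
	return plaintext, nil
}

func main() {
	plaintext := []byte("{\"id\":\"1\"}\n")
	a, err := compress(plaintext)
	if err != nil {
		panic(err)
	}
	b, err := compress(plaintext)
	if err != nil {
		panic(err)
	}
	fmt.Println("deterministic:", bytes.Equal(a, b))
	_, err = decompressVerified(a, sha256Hex(plaintext))
	fmt.Println("verified:", err == nil)
}
```

Hashing the plaintext rather than the ciphertext is what lets an unchanged shard be recognised across runs, since age output differs on every encryption even for identical input.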

View File

@ -28,8 +28,9 @@ parsers, and product-specific ranking in the apps.
schema-version checks, transactions, safe identifier quoting, and generic
query helpers.
- Snapshot packing: manifest format, JSONL/Gzip shards, table filters,
import progress, sidecar registration, backward-compatible manifest reads,
and import callbacks.
per-file fingerprints, import progress, incremental import planning,
sidecar registration, backward-compatible manifest reads, and import
callbacks.
- Git mirror mechanics: clone/init, pull, origin management, path-scoped
commits, push retry behavior, and portable SQLite checkout cleanup.
- Sync freshness semantics: cursor/freshness records, stale checks, manifest

View File

@ -21,30 +21,30 @@ go test ./...
6. Tag the next semver release from `main`:
```bash
-git tag -s v0.4.0
+git tag -s v0.5.0
 git push origin main
-git push origin v0.4.0
+git push origin v0.5.0
```
7. Prime and verify module proxy visibility:
```bash
-GOPROXY=https://proxy.golang.org go list -m github.com/vincentkoc/crawlkit@v0.4.0
-go list -m github.com/vincentkoc/crawlkit@v0.4.0
+GOPROXY=https://proxy.golang.org go list -m github.com/openclaw/crawlkit@v0.5.0
+go list -m github.com/openclaw/crawlkit@v0.5.0
```
8. Bump downstream apps to the new tag and commit their `go.mod`/`go.sum` updates:
```bash
-go get github.com/vincentkoc/crawlkit@v0.4.0
+go get github.com/openclaw/crawlkit@v0.5.0
go mod tidy
```
`pkg.go.dev` indexes public modules automatically after the tag is reachable.
Use a patch tag such as `v0.3.17` only for narrow bug fixes on the existing API.
Use a minor tag such as `v0.4.0` for broad shared TUI or crawler infrastructure
changes. This branch is a `v0.4.0`-shaped release.
Use a patch tag only for narrow bug fixes on the existing API. Use a minor tag
for broad crawler infrastructure changes. The module-path move needs a new tag
on `openclaw/crawlkit` before downstream apps can drop local `replace` lines.
## Versioning
@ -52,5 +52,5 @@ Keep `v0.x.y` while the downstream crawler rewires are still settling. If the
module ever reaches `v2`, Go requires the module path to become:
```text
-github.com/vincentkoc/crawlkit/v2
+github.com/openclaw/crawlkit/v2
```

embed/ollama.go (new file, 91 lines added)

@ -0,0 +1,91 @@
package embed
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)
type ollamaProvider struct {
client *http.Client
baseURL string
model string
maxInputChars int
}
type ollamaEmbedRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
}
type ollamaEmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"`
}
func newOllamaProvider(settings providerSettings) Provider {
return &ollamaProvider{
client: settings.HTTPClient,
baseURL: settings.BaseURL,
model: settings.Model,
maxInputChars: settings.MaxInputChars,
}
}
func (p *ollamaProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
if len(inputs) == 0 {
return EmbeddingBatch{Model: p.model}, nil
}
payload := ollamaEmbedRequest{
Model: p.model,
Input: trimInputs(inputs, p.maxInputChars),
}
var response ollamaEmbedResponse
if err := postJSON(ctx, p.client, p.baseURL+"/api/embed", "", payload, &response); err != nil {
return EmbeddingBatch{}, err
}
if len(response.Embeddings) != len(inputs) {
return EmbeddingBatch{}, fmt.Errorf("ollama embedding response returned %d vectors for %d inputs", len(response.Embeddings), len(inputs))
}
dimensions, err := inferDimensions(response.Embeddings)
if err != nil {
return EmbeddingBatch{}, err
}
model := response.Model
if model == "" {
model = p.model
}
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: response.Embeddings}, nil
}
func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, payload any, target any) error {
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal embedding request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build embedding request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("embedding request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
}
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
return fmt.Errorf("decode embedding response: %w", err)
}
return nil
}

View File

@ -0,0 +1,82 @@
package embed
import (
"context"
"fmt"
"net/http"
)
type openAICompatibleProvider struct {
client *http.Client
baseURL string
apiKey string
model string
maxInputChars int
}
type openAIEmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
}
type openAIEmbeddingResponse struct {
Model string `json:"model"`
Data []openAIEmbeddingItem `json:"data"`
}
type openAIEmbeddingItem struct {
Index *int `json:"index"`
Embedding []float32 `json:"embedding"`
}
func newOpenAICompatibleProvider(settings providerSettings) Provider {
return &openAICompatibleProvider{
client: settings.HTTPClient,
baseURL: settings.BaseURL,
apiKey: settings.APIKey,
model: settings.Model,
maxInputChars: settings.MaxInputChars,
}
}
func (p *openAICompatibleProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
if len(inputs) == 0 {
return EmbeddingBatch{Model: p.model}, nil
}
payload := openAIEmbeddingRequest{
Model: p.model,
Input: trimInputs(inputs, p.maxInputChars),
}
var response openAIEmbeddingResponse
if err := postJSON(ctx, p.client, p.baseURL+"/embeddings", p.apiKey, payload, &response); err != nil {
return EmbeddingBatch{}, err
}
if len(response.Data) != len(inputs) {
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response returned %d vectors for %d inputs", len(response.Data), len(inputs))
}
vectors := make([][]float32, len(inputs))
seen := make([]bool, len(inputs))
for position, item := range response.Data {
index := position
if item.Index != nil {
index = *item.Index
}
if index < 0 || index >= len(inputs) {
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response index %d out of range", index)
}
if seen[index] {
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response duplicated index %d", index)
}
seen[index] = true
vectors[index] = item.Embedding
}
dimensions, err := inferDimensions(vectors)
if err != nil {
return EmbeddingBatch{}, err
}
model := response.Model
if model == "" {
model = p.model
}
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: vectors}, nil
}
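The index handling in `Embed` above places each returned vector at the position the server reports, falling back to arrival order when no index is present. A standalone sketch of that reordering step, with the hypothetical names `item`, `orderByIndex`, and `intPtr`, could look like this:

```go
package main

import "fmt"

// item is a hypothetical stand-in for one entry of an embeddings response.
// Index is a pointer so that a missing field can be told apart from index 0.
type item struct {
	Index     *int
	Embedding []float32
}

// orderByIndex places each embedding at its reported index and rejects
// responses with a wrong count, out-of-range indexes, or duplicates.
func orderByIndex(items []item, n int) ([][]float32, error) {
	if len(items) != n {
		return nil, fmt.Errorf("got %d vectors for %d inputs", len(items), n)
	}
	vectors := make([][]float32, n)
	seen := make([]bool, n)
	for position, it := range items {
		index := position // fall back to arrival order
		if it.Index != nil {
			index = *it.Index
		}
		if index < 0 || index >= n {
			return nil, fmt.Errorf("index %d out of range", index)
		}
		if seen[index] {
			return nil, fmt.Errorf("duplicated index %d", index)
		}
		seen[index] = true
		vectors[index] = it.Embedding
	}
	return vectors, nil
}

func intPtr(i int) *int { return &i }

func main() {
	vectors, err := orderByIndex([]item{
		{Index: intPtr(1), Embedding: []float32{4, 5}},
		{Index: intPtr(0), Embedding: []float32{1, 2}},
	}, 2)
	fmt.Println(vectors, err)
}
```

Because the count, range, and duplicate checks together guarantee a permutation, every slot in the output is filled exactly once when no error is returned.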

embed/provider.go (new file, 317 lines added)

@ -0,0 +1,317 @@
package embed
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
)
const (
ProviderOpenAI = "openai"
ProviderOllama = "ollama"
ProviderLlamaCpp = "llamacpp"
ProviderOpenAICompatible = "openai_compatible"
DefaultOpenAIBaseURL = "https://api.openai.com/v1"
DefaultOllamaBaseURL = "http://127.0.0.1:11434"
DefaultLlamaCppBaseURL = "http://127.0.0.1:8080/v1"
DefaultOpenAIModel = "text-embedding-3-small"
DefaultLocalEmbeddingModel = "nomic-embed-text"
DefaultBatchSize = 64
DefaultMaxInputChars = 12000
DefaultRequestTimeout = 2 * time.Minute
DefaultProbeTimeout = 2 * time.Second
)
type Config struct {
Provider string
Model string
BaseURL string
APIKeyEnv string
RequestTimeout string
MaxInputChars int
}
type Provider interface {
Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error)
}
type EmbeddingBatch struct {
Model string
Dimensions int
Vectors [][]float32
}
type HTTPError struct {
StatusCode int
Body string
}
func (e *HTTPError) Error() string {
return fmt.Sprintf("embedding request failed with HTTP %d: %s", e.StatusCode, e.Body)
}
func IsRateLimitError(err error) bool {
var httpErr *HTTPError
return errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusTooManyRequests
}
type CheckResult struct {
Provider string
Model string
BaseURL string
Status string
Warning string
Probed bool
}
type Option func(*providerOptions)
type providerOptions struct {
httpClient *http.Client
timeoutOverride time.Duration
}
type providerSettings struct {
Name string
Model string
BaseURL string
APIKey string
MaxInputChars int
Timeout time.Duration
HTTPClient *http.Client
}
func WithHTTPClient(client *http.Client) Option {
return func(opts *providerOptions) {
opts.httpClient = client
}
}
func WithRequestTimeout(timeout time.Duration) Option {
return func(opts *providerOptions) {
opts.timeoutOverride = timeout
}
}
func NewProvider(cfg Config, opts ...Option) (Provider, error) {
settings, err := resolveProviderConfig(cfg, true, opts...)
if err != nil {
return nil, err
}
return newProvider(settings)
}
func CheckProvider(ctx context.Context, cfg Config) CheckResult {
settings, err := resolveProviderConfig(cfg, true, WithRequestTimeout(DefaultProbeTimeout))
if err != nil {
return CheckResult{
Provider: normalizedProviderName(cfg.Provider),
Model: strings.TrimSpace(cfg.Model),
BaseURL: strings.TrimSpace(cfg.BaseURL),
Status: "warning",
Warning: err.Error(),
}
}
result := CheckResult{
Provider: settings.Name,
Model: settings.Model,
BaseURL: settings.BaseURL,
Status: "ok",
}
if !shouldProbe(settings) {
return result
}
provider, err := newProvider(settings)
if err != nil {
result.Status = "warning"
result.Warning = err.Error()
return result
}
probeCtx, cancel := context.WithTimeout(ctx, DefaultProbeTimeout)
defer cancel()
if _, err := provider.Embed(probeCtx, []string{"crawlkit probe"}); err != nil {
result.Status = "warning"
result.Warning = err.Error()
return result
}
result.Probed = true
return result
}
func resolveProviderConfig(cfg Config, validateAPIKey bool, opts ...Option) (providerSettings, error) {
options := providerOptions{}
for _, opt := range opts {
opt(&options)
}
name := normalizedProviderName(cfg.Provider)
if name == "" {
name = ProviderOpenAI
}
model := strings.TrimSpace(cfg.Model)
if model == "" {
model = defaultModel(name)
}
baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/")
if baseURL == "" {
switch name {
case ProviderOpenAI:
baseURL = DefaultOpenAIBaseURL
case ProviderOllama:
baseURL = DefaultOllamaBaseURL
case ProviderLlamaCpp:
baseURL = DefaultLlamaCppBaseURL
case ProviderOpenAICompatible:
return providerSettings{}, fmt.Errorf("embedding provider %q requires base_url", name)
}
}
timeout := DefaultRequestTimeout
if strings.TrimSpace(cfg.RequestTimeout) != "" {
parsed, err := time.ParseDuration(cfg.RequestTimeout)
if err != nil {
return providerSettings{}, fmt.Errorf("parse embeddings request_timeout: %w", err)
}
if parsed <= 0 {
return providerSettings{}, errors.New("embeddings request_timeout must be positive")
}
timeout = parsed
}
if options.timeoutOverride > 0 && options.timeoutOverride < timeout {
timeout = options.timeoutOverride
}
maxInputChars := cfg.MaxInputChars
if maxInputChars <= 0 {
maxInputChars = DefaultMaxInputChars
}
switch name {
case ProviderOpenAI, ProviderOllama, ProviderLlamaCpp, ProviderOpenAICompatible:
default:
return providerSettings{}, fmt.Errorf("unsupported embedding provider %q", name)
}
apiKey, err := resolveAPIKey(name, cfg.APIKeyEnv, validateAPIKey)
if err != nil {
return providerSettings{}, err
}
client := options.httpClient
if client == nil {
client = &http.Client{Timeout: timeout}
}
if _, err := url.ParseRequestURI(baseURL); err != nil {
return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: %w", baseURL, err)
}
return providerSettings{
Name: name,
Model: model,
BaseURL: baseURL,
APIKey: apiKey,
MaxInputChars: maxInputChars,
Timeout: timeout,
HTTPClient: client,
}, nil
}
func newProvider(settings providerSettings) (Provider, error) {
switch settings.Name {
case ProviderOllama:
return newOllamaProvider(settings), nil
case ProviderOpenAI, ProviderLlamaCpp, ProviderOpenAICompatible:
return newOpenAICompatibleProvider(settings), nil
default:
return nil, fmt.Errorf("unsupported embedding provider %q", settings.Name)
}
}
func resolveAPIKey(provider, apiKeyEnv string, validate bool) (string, error) {
envName := strings.TrimSpace(apiKeyEnv)
required := provider == ProviderOpenAI
if envName == "" {
if required {
envName = "OPENAI_API_KEY"
} else {
return "", nil
}
}
value := strings.TrimSpace(os.Getenv(envName))
if value == "" {
if required || validate {
return "", fmt.Errorf("embedding provider %q requires API key env %s", provider, envName)
}
return "", nil
}
return value, nil
}
func normalizedProviderName(provider string) string {
return strings.ToLower(strings.TrimSpace(provider))
}
func defaultModel(provider string) string {
switch provider {
case ProviderOllama, ProviderLlamaCpp:
return DefaultLocalEmbeddingModel
default:
return DefaultOpenAIModel
}
}
func shouldProbe(settings providerSettings) bool {
switch settings.Name {
case ProviderOllama, ProviderLlamaCpp:
return true
case ProviderOpenAICompatible:
return isLoopbackBaseURL(settings.BaseURL)
default:
return false
}
}
func isLoopbackBaseURL(rawURL string) bool {
parsed, err := url.Parse(rawURL)
if err != nil {
return false
}
host := parsed.Hostname()
if host == "localhost" {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}
func trimInputs(inputs []string, maxChars int) []string {
if maxChars <= 0 {
maxChars = DefaultMaxInputChars
}
out := make([]string, len(inputs))
for i, input := range inputs {
runes := []rune(input)
if len(runes) > maxChars {
runes = runes[:maxChars]
}
out[i] = string(runes)
}
return out
}
func inferDimensions(vectors [][]float32) (int, error) {
dimensions := 0
for _, vector := range vectors {
if len(vector) == 0 {
return 0, errors.New("embedding response contained an empty vector")
}
if dimensions == 0 {
dimensions = len(vector)
continue
}
if len(vector) != dimensions {
return 0, fmt.Errorf("embedding response dimensions mismatch: got %d want %d", len(vector), dimensions)
}
}
return dimensions, nil
}
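
Review note: two helpers in the file above are easy to misread, so here is a minimal standalone sketch of both behaviors. `truncateRunes` and `isLoopback` are hypothetical names chosen for illustration; they follow the same approach as `trimInputs` and `isLoopbackBaseURL` in the diff but are not the package's API.

```go
package main

import (
	"fmt"
	"net"
	"net/url"
)

// truncateRunes is a hypothetical helper that mirrors the per-input step of
// trimInputs: it cuts on rune boundaries, so a multi-byte character is never
// split in the middle.
func truncateRunes(s string, maxChars int) string {
	runes := []rune(s)
	if len(runes) > maxChars {
		runes = runes[:maxChars]
	}
	return string(runes)
}

// isLoopback is a hypothetical helper that follows the same steps as
// isLoopbackBaseURL: parse the URL, accept the literal host "localhost",
// otherwise require an IP literal in a loopback range.
func isLoopback(rawURL string) bool {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return false
	}
	host := parsed.Hostname()
	if host == "localhost" {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}

func main() {
	fmt.Println(truncateRunes("héllo", 2))
	fmt.Println(isLoopback("http://[::1]:8080/v1"))
	fmt.Println(isLoopback("https://api.example.com/v1"))
}
```

Because the limit is counted in runes, `MaxInputChars` bounds characters rather than bytes. The loopback check does no DNS resolution, so a hostname that merely resolves to 127.0.0.1 is not treated as local.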

embed/provider_test.go Normal file

@ -0,0 +1,453 @@
package embed
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
)
func TestOllamaProviderEmbeds(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/embed", r.URL.Path)
assert.Equal(t, http.MethodPost, r.Method)
var req ollamaEmbedRequest
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, "nomic-embed-text", req.Model)
assert.Equal(t, []string{"abcd", "xy"}, req.Input)
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`))
}))
defer server.Close()
provider, err := NewProvider(Config{
Provider: ProviderOllama,
Model: "nomic-embed-text",
BaseURL: server.URL,
MaxInputChars: 4,
RequestTimeout: "5s",
})
require.NoError(t, err)
batch, err := provider.Embed(context.Background(), []string{"abcdef", "xy"})
require.NoError(t, err)
require.Equal(t, "nomic-embed-text", batch.Model)
require.Equal(t, 3, batch.Dimensions)
require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, batch.Vectors)
}
type assertAPI struct{}
type requireAPI struct{}
var assert assertAPI
var require requireAPI
func (assertAPI) Equal(t *testing.T, want, got any) bool {
t.Helper()
if !reflect.DeepEqual(want, got) {
t.Errorf("not equal:\nwant: %#v\n got: %#v", want, got)
return false
}
return true
}
func (assertAPI) NoError(t *testing.T, err error) bool {
t.Helper()
if err != nil {
t.Errorf("unexpected error: %v", err)
return false
}
return true
}
func (requireAPI) NoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func (requireAPI) Equal(t *testing.T, want, got any) {
t.Helper()
if !reflect.DeepEqual(want, got) {
t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got)
}
}
func (requireAPI) Same(t *testing.T, want, got any) {
t.Helper()
if !reflect.ValueOf(want).IsValid() || !reflect.ValueOf(got).IsValid() ||
reflect.ValueOf(want).Pointer() != reflect.ValueOf(got).Pointer() {
t.Fatalf("not same:\nwant: %#v\n got: %#v", want, got)
}
}
func (requireAPI) True(t *testing.T, value bool) {
t.Helper()
if !value {
t.Fatal("expected true")
}
}
func (requireAPI) False(t *testing.T, value bool) {
t.Helper()
if value {
t.Fatal("expected false")
}
}
func (requireAPI) Empty(t *testing.T, value string) {
t.Helper()
if value != "" {
t.Fatalf("expected empty string, got %q", value)
}
}
func (requireAPI) Contains(t *testing.T, value, needle string) {
t.Helper()
if !strings.Contains(value, needle) {
t.Fatalf("expected %q to contain %q", value, needle)
}
}
func (requireAPI) ErrorContains(t *testing.T, err error, needle string) {
t.Helper()
if err == nil {
t.Fatalf("expected error containing %q, got nil", needle)
}
if !strings.Contains(err.Error(), needle) {
t.Fatalf("expected error containing %q, got %q", needle, err.Error())
}
}
func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/embeddings", r.URL.Path)
assert.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
var req openAIEmbeddingRequest
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, "local-model", req.Model)
assert.Equal(t, []string{"one", "two"}, req.Input)
_, _ = w.Write([]byte(`{
"model":"local-model",
"data":[
{"index":1,"embedding":[3,4]},
{"index":0,"embedding":[1,2]}
]
}`))
}))
defer server.Close()
t.Setenv("CRAWLKIT_EMBED_KEY", "secret")
provider, err := NewProvider(Config{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
APIKeyEnv: "CRAWLKIT_EMBED_KEY",
RequestTimeout: "5s",
})
require.NoError(t, err)
batch, err := provider.Embed(context.Background(), []string{"one", "two"})
require.NoError(t, err)
require.Equal(t, "local-model", batch.Model)
require.Equal(t, 2, batch.Dimensions)
require.Equal(t, [][]float32{{1, 2}, {3, 4}}, batch.Vectors)
}
func TestProviderFactoryDefaultsAndValidation(t *testing.T) {
t.Setenv("OPENAI_API_KEY", "openai-secret")
openAI, err := resolveProviderConfig(Config{
Provider: ProviderOpenAI,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL)
require.Equal(t, DefaultOpenAIModel, openAI.Model)
require.Equal(t, "openai-secret", openAI.APIKey)
ollama, err := resolveProviderConfig(Config{
Provider: ProviderOllama,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL)
require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model)
llamaCpp, err := resolveProviderConfig(Config{
Provider: ProviderLlamaCpp,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL)
_, err = resolveProviderConfig(Config{
Provider: ProviderOpenAICompatible,
RequestTimeout: "5s",
}, true)
require.ErrorContains(t, err, "requires base_url")
}
func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) {
t.Setenv("OPENAI_API_KEY", "")
_, err := NewProvider(Config{
Provider: ProviderOpenAI,
RequestTimeout: "5s",
})
require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY")
}
func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) {
t.Setenv("MISSING_EMBED_KEY", "")
_, err := NewProvider(Config{
Provider: "bogus",
APIKeyEnv: "MISSING_EMBED_KEY",
RequestTimeout: "5s",
})
require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"")
}
func TestCheckProviderProbesLocalProvider(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/embed", r.URL.Path)
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`))
}))
defer server.Close()
result := CheckProvider(context.Background(), Config{
Provider: ProviderOllama,
Model: "nomic-embed-text",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.Equal(t, "ok", result.Status)
require.True(t, result.Probed)
require.Empty(t, result.Warning)
require.Equal(t, server.URL, result.BaseURL)
}
func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "not ready", http.StatusServiceUnavailable)
}))
defer server.Close()
result := CheckProvider(context.Background(), Config{
Provider: ProviderOllama,
Model: "nomic-embed-text",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.Equal(t, "warning", result.Status)
require.Contains(t, result.Warning, "HTTP 503")
require.False(t, result.Probed)
}
func TestProviderExposesRateLimitErrors(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "rate limited", http.StatusTooManyRequests)
}))
defer server.Close()
provider, err := NewProvider(Config{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)
_, err = provider.Embed(context.Background(), []string{"one"})
require.ErrorContains(t, err, "HTTP 429")
require.True(t, IsRateLimitError(err))
}
func TestProviderRejectsInvalidResponses(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`))
}))
defer server.Close()
provider, err := NewProvider(Config{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)
_, err = provider.Embed(context.Background(), []string{"one", "two"})
require.ErrorContains(t, err, "dimensions mismatch")
}
func TestEmbeddingProvidersHandleEmptyInputsAndIndexErrors(t *testing.T) {
t.Parallel()
settings := providerSettings{
Name: ProviderOllama,
Model: "model",
BaseURL: "http://127.0.0.1:1",
MaxInputChars: 10,
HTTPClient: http.DefaultClient,
}
ollama := newOllamaProvider(settings)
batch, err := ollama.Embed(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, "model", batch.Model)
settings.Name = ProviderOpenAICompatible
openai := newOpenAICompatibleProvider(settings)
batch, err = openai.Embed(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, "model", batch.Model)
tests := []struct {
name string
body string
inputs []string
want string
}{
{name: "count", body: `{"data":[]}`, inputs: []string{"one"}, want: "returned 0 vectors for 1 inputs"},
{name: "range", body: `{"data":[{"index":2,"embedding":[1]}]}`, inputs: []string{"one"}, want: "index 2 out of range"},
{name: "duplicate", body: `{"data":[{"index":0,"embedding":[1]},{"index":0,"embedding":[2]}]}`, inputs: []string{"one", "two"}, want: "duplicated index 0"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(tc.body))
}))
defer server.Close()
provider, err := NewProvider(Config{
Provider: ProviderOpenAICompatible,
Model: "model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)
_, err = provider.Embed(context.Background(), tc.inputs)
require.ErrorContains(t, err, tc.want)
})
}
}
func TestProviderOptionsAndProbeDecisions(t *testing.T) {
t.Parallel()
client := &http.Client{Timeout: time.Second}
settings, err := resolveProviderConfig(Config{
Provider: ProviderOllama,
BaseURL: "http://127.0.0.1:11434/",
RequestTimeout: "30s",
}, true, WithHTTPClient(client), WithRequestTimeout(50*time.Millisecond))
require.NoError(t, err)
require.Same(t, client, settings.HTTPClient)
require.Equal(t, 50*time.Millisecond, settings.Timeout)
require.Equal(t, "http://127.0.0.1:11434", settings.BaseURL)
require.True(t, shouldProbe(settings))
require.True(t, isLoopbackBaseURL("http://localhost:8080/v1"))
require.True(t, isLoopbackBaseURL("http://[::1]:8080/v1"))
require.False(t, isLoopbackBaseURL("https://api.example.com/v1"))
require.False(t, isLoopbackBaseURL("://bad"))
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAI}))
require.True(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "http://localhost:8080/v1"}))
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "https://api.example.com/v1"}))
}
func TestProviderValidationEdges(t *testing.T) {
t.Parallel()
_, err := resolveProviderConfig(Config{
Provider: ProviderOllama,
RequestTimeout: "not-a-duration",
}, true)
require.ErrorContains(t, err, "parse embeddings request_timeout")
_, err = resolveProviderConfig(Config{
Provider: ProviderOllama,
RequestTimeout: "0s",
}, true)
require.ErrorContains(t, err, "must be positive")
_, err = resolveProviderConfig(Config{
Provider: ProviderOllama,
BaseURL: "://bad",
}, true)
require.ErrorContains(t, err, "invalid embeddings base_url")
key, err := resolveAPIKey(ProviderOpenAICompatible, "MISSING_EMBED_KEY", false)
require.NoError(t, err)
require.Empty(t, key)
_, err = newProvider(providerSettings{Name: "bogus"})
require.ErrorContains(t, err, "unsupported embedding provider")
require.Equal(t, []string{"abc"}, trimInputs([]string{"abc"}, 0))
_, err = inferDimensions([][]float32{{}})
require.ErrorContains(t, err, "empty vector")
}
func TestOllamaProviderResponseEdges(t *testing.T) {
t.Parallel()
countServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/embed", r.URL.Path)
_, _ = w.Write([]byte(`{"embeddings":[]}`))
}))
defer countServer.Close()
provider := newOllamaProvider(providerSettings{
HTTPClient: countServer.Client(),
BaseURL: countServer.URL,
Model: "fallback-model",
MaxInputChars: 10,
})
_, err := provider.Embed(context.Background(), []string{"one"})
require.ErrorContains(t, err, "returned 0 vectors for 1 inputs")
modelServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/embed", r.URL.Path)
_, _ = w.Write([]byte(`{"embeddings":[[1,2]]}`))
}))
defer modelServer.Close()
provider = newOllamaProvider(providerSettings{
HTTPClient: modelServer.Client(),
BaseURL: modelServer.URL,
Model: "fallback-model",
MaxInputChars: 10,
})
batch, err := provider.Embed(context.Background(), []string{"one"})
require.NoError(t, err)
require.Equal(t, "fallback-model", batch.Model)
}
func TestCheckProviderSkipsRemoteCompatibleProbe(t *testing.T) {
t.Parallel()
result := CheckProvider(context.Background(), Config{
Provider: ProviderOpenAICompatible,
Model: "remote-model",
BaseURL: "https://api.example.com/v1",
RequestTimeout: "5s",
})
require.Equal(t, "ok", result.Status)
require.False(t, result.Probed)
require.Empty(t, result.Warning)
}
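
Review note: the tests above expect OpenAI-style responses to be reordered by `index` and to fail on count mismatches, out-of-range indexes, and duplicates. The provider implementation is not part of this excerpt, so the following is only a sketch of one way to satisfy those expectations. The `item` type and `orderByIndex` function are hypothetical, and the error strings are copied from the test expectations.

```go
package main

import "fmt"

// item is a hypothetical stand-in for one entry of an OpenAI-style "data" array.
type item struct {
	Index     int
	Embedding []float32
}

// orderByIndex places each embedding in the slot named by its index and
// rejects count mismatches, out-of-range indexes, and duplicate indexes.
func orderByIndex(items []item, inputs int) ([][]float32, error) {
	if len(items) != inputs {
		return nil, fmt.Errorf("returned %d vectors for %d inputs", len(items), inputs)
	}
	out := make([][]float32, inputs)
	seen := make([]bool, inputs)
	for _, it := range items {
		if it.Index < 0 || it.Index >= inputs {
			return nil, fmt.Errorf("index %d out of range", it.Index)
		}
		if seen[it.Index] {
			return nil, fmt.Errorf("duplicated index %d", it.Index)
		}
		seen[it.Index] = true
		out[it.Index] = it.Embedding
	}
	return out, nil
}

func main() {
	// Same shape as the mocked response: index 1 arrives before index 0.
	vectors, err := orderByIndex([]item{
		{Index: 1, Embedding: []float32{3, 4}},
		{Index: 0, Embedding: []float32{1, 2}},
	}, 2)
	fmt.Println(vectors, err)
}
```

Checking the count first means the "duplicate" test case (two entries for two inputs, both with index 0) reaches the duplicate check rather than failing earlier.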

go.mod

@ -1,31 +1,32 @@
module github.com/vincentkoc/crawlkit
module github.com/openclaw/crawlkit
go 1.26.2
require (
filippo.io/age v1.3.1
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/charmbracelet/x/ansi v0.11.6
github.com/charmbracelet/x/ansi v0.11.7
github.com/charmbracelet/x/term v0.2.2
github.com/mattn/go-isatty v0.0.20
github.com/pelletier/go-toml/v2 v2.3.0
github.com/mattn/go-isatty v0.0.22
github.com/pelletier/go-toml/v2 v2.3.1
modernc.org/sqlite v1.50.0
)
require (
filippo.io/hpke v0.4.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/clipperhouse/displaywidth v0.11.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/mattn/go-runewidth v0.0.23 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
@ -33,8 +34,9 @@ require (
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.3.8 // indirect
golang.org/x/text v0.31.0 // indirect
modernc.org/libc v1.72.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect

go.sum

@ -1,3 +1,9 @@
c2sp.org/CCTV/age v0.0.0-20251208015420-e9274a7bdbfd h1:ZLsPO6WdZ5zatV4UfVpr7oAwLGRZ+sebTUruuM4Ra3M=
c2sp.org/CCTV/age v0.0.0-20251208015420-e9274a7bdbfd/go.mod h1:SrHC2C7r5GkDk8R+NFVzYy/sdj0Ypg9htaPXQq5Cqeo=
filippo.io/age v1.3.1 h1:hbzdQOJkuaMEpRCLSN1/C5DX74RPcNCk6oqhKMXmZi0=
filippo.io/age v1.3.1/go.mod h1:EZorDTYUxt836i3zdori5IJX/v2Lj6kWFU0cfh6C0D4=
filippo.io/hpke v0.4.0 h1:p575VVQ6ted4pL+it6M00V/f2qTZITO0zgmdKCkd5+A=
filippo.io/hpke v0.4.0/go.mod h1:EmAN849/P3qdeK+PCMkDpDm83vRHM5cDipBJ8xbQLVY=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
@ -8,18 +14,16 @@ github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
@ -30,14 +34,14 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@ -46,14 +50,16 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
@ -61,11 +67,10 @@ golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=


@ -4,10 +4,13 @@ import (
"bufio"
"compress/gzip"
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"os"
"path/filepath"
@ -15,7 +18,7 @@ import (
"strings"
"time"
"github.com/vincentkoc/crawlkit/store"
"github.com/openclaw/crawlkit/store"
)
const ManifestName = "manifest.json"
@ -38,6 +41,7 @@ type ImportOptions struct {
DeleteTables []string
DeleteTable DeleteFunc
Filter RowFilter
ImportRow RowImportFunc
Progress func(ImportProgress)
BeforeImport func(context.Context, *sql.Tx) error
AfterImport func(context.Context, *sql.Tx) error
@ -45,6 +49,8 @@ type ImportOptions struct {
type RowFilter func(table string, row map[string]any) (bool, error)
type RowImportFunc func(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error
type DeleteFunc func(ctx context.Context, tx *sql.Tx, table string) error
type ImportProgress struct {
@ -72,15 +78,58 @@ type Manifest struct {
}
type TableManifest struct {
Name string `json:"name"`
File string `json:"file,omitempty"`
Files []string `json:"files"`
Columns []string `json:"columns"`
Rows int `json:"rows"`
Name string `json:"name"`
File string `json:"file,omitempty"`
Files []string `json:"files"`
FileManifests []FileManifest `json:"file_manifests,omitempty"`
Columns []string `json:"columns"`
Rows int `json:"rows"`
}
type FileManifest struct {
Path string `json:"path"`
Rows int `json:"rows"`
Size int64 `json:"size,omitempty"`
SHA256 string `json:"sha256,omitempty"`
}
var ErrNoManifest = errors.New("pack manifest not found")
type TableImportMode string
const (
TableImportSkip TableImportMode = "skip"
TableImportReplace TableImportMode = "replace"
TableImportFiles TableImportMode = "files"
)
type ImportPlan struct {
Full bool
Reason string
Tables []TableImportPlan
}
type TableImportPlan struct {
Table TableManifest
Mode TableImportMode
Files []FileManifest
Reason string
}
type IncrementalImportOptions struct {
DB *sql.DB
RootDir string
Previous Manifest
Current Manifest
Plan ImportPlan
DeleteTable DeleteFunc
Filter RowFilter
ImportRow RowImportFunc
Progress func(ImportProgress)
BeforeImport func(context.Context, *sql.Tx) error
AfterImport func(context.Context, *sql.Tx) error
}
func Export(ctx context.Context, opts ExportOptions) (Manifest, error) {
if opts.DB == nil {
return Manifest{}, errors.New("db is required")
@ -170,7 +219,7 @@ func Import(ctx context.Context, opts ImportOptions) (Manifest, error) {
}
}
for _, table := range manifest.Tables {
rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.Progress)
rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.ImportRow, opts.Progress)
if err != nil {
return Manifest{}, err
}
@ -188,6 +237,130 @@ func Import(ctx context.Context, opts ImportOptions) (Manifest, error) {
return manifest, nil
}
func PlanIncrementalImport(previous, current Manifest) ImportPlan {
if current.Version != previous.Version {
return ImportPlan{Full: true, Reason: "manifest version changed"}
}
previousTables := make(map[string]TableManifest, len(previous.Tables))
for _, table := range previous.Tables {
previousTables[table.Name] = table
}
currentTables := make(map[string]TableManifest, len(current.Tables))
for _, table := range current.Tables {
currentTables[table.Name] = table
}
for name := range previousTables {
if _, ok := currentTables[name]; !ok {
return ImportPlan{Full: true, Reason: "table removed: " + name}
}
}
plan := ImportPlan{}
for _, table := range current.Tables {
previousTable, ok := previousTables[table.Name]
if !ok {
plan.Tables = append(plan.Tables, TableImportPlan{
Table: table,
Mode: TableImportReplace,
Files: tableFileManifests(table),
Reason: "new table",
})
continue
}
tablePlan := planTableIncrement(previousTable, table)
plan.Tables = append(plan.Tables, tablePlan)
}
return plan
}
func (p ImportPlan) Changed() bool {
if p.Full {
return true
}
for _, table := range p.Tables {
if table.Mode != TableImportSkip {
return true
}
}
return false
}
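
Review note: the table-level decisions in `PlanIncrementalImport` can be condensed into a small standalone model. This is a sketch under stated assumptions: `decision` and `plan` are hypothetical names, the manifest version is modeled as an `int` (its real type is not shown in this diff), and tables present in both manifests are only marked `"compare"` because the per-table comparison is handled by a separate function.

```go
package main

import "fmt"

// decision is a hypothetical, simplified stand-in for ImportPlan.
type decision struct {
	Full   bool
	Reason string
	Tables map[string]string // table name -> "replace" or "compare"
}

// plan models the table-level rules: a version change or a removed table
// forces a full import, a new table is replaced wholesale, and a table present
// in both manifests is handed to a per-table comparison (not modeled here).
func plan(prevVersion, curVersion int, previous, current []string) decision {
	if prevVersion != curVersion {
		return decision{Full: true, Reason: "manifest version changed"}
	}
	inCurrent := make(map[string]bool, len(current))
	for _, name := range current {
		inCurrent[name] = true
	}
	inPrevious := make(map[string]bool, len(previous))
	for _, name := range previous {
		inPrevious[name] = true
		if !inCurrent[name] {
			return decision{Full: true, Reason: "table removed: " + name}
		}
	}
	d := decision{Tables: make(map[string]string, len(current))}
	for _, name := range current {
		if inPrevious[name] {
			d.Tables[name] = "compare"
		} else {
			d.Tables[name] = "replace"
		}
	}
	return d
}

func main() {
	fmt.Println(plan(1, 1, []string{"pages"}, []string{"pages", "links"}))
	fmt.Println(plan(1, 1, []string{"pages", "links"}, []string{"pages"}))
}
```

One difference from the diff: the sketch walks the previous table names in slice order, whereas the original ranges over a map, so with several removed tables the reported name may differ between runs.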
func ImportIncremental(ctx context.Context, opts IncrementalImportOptions) (Manifest, ImportPlan, error) {
if opts.DB == nil {
return Manifest{}, ImportPlan{}, errors.New("db is required")
}
current := opts.Current
var err error
if len(current.Tables) == 0 {
current, err = ReadManifest(opts.RootDir)
if err != nil {
return Manifest{}, ImportPlan{}, err
}
}
plan := opts.Plan
if len(plan.Tables) == 0 && !plan.Full && plan.Reason == "" {
plan = PlanIncrementalImport(opts.Previous, current)
}
if plan.Full {
return Manifest{}, plan, errors.New("incremental import requires a non-full plan: " + plan.Reason)
}
if !plan.Changed() {
return current, plan, nil
}
tx, err := opts.DB.BeginTx(ctx, nil)
if err != nil {
return Manifest{}, plan, fmt.Errorf("begin incremental import tx: %w", err)
}
committed := false
defer func() {
if !committed {
_ = tx.Rollback()
}
}()
if opts.BeforeImport != nil {
if err := opts.BeforeImport(ctx, tx); err != nil {
return Manifest{}, plan, err
}
}
for _, tablePlan := range plan.Tables {
switch tablePlan.Mode {
case TableImportSkip:
continue
case TableImportReplace:
if err := deleteImportTable(ctx, tx, tablePlan.Table.Name, opts.DeleteTable); err != nil {
return Manifest{}, plan, err
}
rows, err := importTable(ctx, tx, opts.RootDir, tablePlan.Table, opts.Filter, opts.ImportRow, opts.Progress)
if err != nil {
return Manifest{}, plan, err
}
reportImportProgress(opts.Progress, ImportProgress{Phase: "table_done", Table: tablePlan.Table.Name, Rows: rows, TotalRows: tablePlan.Table.Rows})
case TableImportFiles:
table := tablePlan.Table
table.File = ""
table.Files = fileManifestPaths(tablePlan.Files)
table.FileManifests = tablePlan.Files
table.Rows = fileManifestRows(tablePlan.Files)
rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.ImportRow, opts.Progress)
if err != nil {
return Manifest{}, plan, err
}
reportImportProgress(opts.Progress, ImportProgress{Phase: "table_done", Table: tablePlan.Table.Name, Rows: rows, TotalRows: table.Rows})
default:
return Manifest{}, plan, fmt.Errorf("unknown table import mode %q for %s", tablePlan.Mode, tablePlan.Table.Name)
}
}
if opts.AfterImport != nil {
if err := opts.AfterImport(ctx, tx); err != nil {
return Manifest{}, plan, err
}
}
if err := tx.Commit(); err != nil {
return Manifest{}, plan, fmt.Errorf("commit incremental import tx: %w", err)
}
committed = true
return current, plan, nil
}
func ReadManifest(rootDir string) (Manifest, error) {
data, err := os.ReadFile(filepath.Join(rootDir, ManifestName))
if errors.Is(err, os.ErrNotExist) {
@ -278,10 +451,10 @@ func exportTable(ctx context.Context, db *sql.DB, rootDir, table string, maxShar
if err := writer.close(); err != nil {
return TableManifest{}, err
}
return TableManifest{Name: table, Files: writer.files, Columns: cols, Rows: count}, nil
return TableManifest{Name: table, Files: writer.files, FileManifests: writer.fileManifests, Columns: cols, Rows: count}, nil
}
func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter, progress func(ImportProgress)) (int, error) {
func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter, importRow RowImportFunc, progress func(ImportProgress)) (int, error) {
files := table.Files
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
files = []string{table.File}
@ -299,7 +472,7 @@ func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableMan
}
fileProgress := ImportProgress{Phase: "file_start", Table: table.Name, File: rel, FileIndex: index + 1, FileCount: len(files), TotalRows: table.Rows}
reportImportProgress(progress, fileProgress)
rows, err := importJSONLGzip(ctx, tx, file, table.Name, filter)
rows, err := importJSONLGzip(ctx, tx, file, table.Name, filter, importRow)
if err != nil {
_ = file.Close()
return totalRows, err
@ -315,7 +488,7 @@ func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableMan
return totalRows, nil
}
func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter) (int, error) {
func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter, importRow RowImportFunc) (int, error) {
gz, err := gzip.NewReader(reader)
if err != nil {
return 0, fmt.Errorf("open gzip for %s: %w", table, err)
@ -341,7 +514,11 @@ func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table st
continue
}
}
if err := insertRow(ctx, tx, table, row); err != nil {
importFunc := importRow
if importFunc == nil {
importFunc = insertRow
}
if err := importFunc(ctx, tx, table, row); err != nil {
return rows, err
}
rows++
@@ -358,6 +535,16 @@ func reportImportProgress(progress func(ImportProgress), event ImportProgress) {
}
}
func deleteImportTable(ctx context.Context, tx *sql.Tx, table string, deleteTable DeleteFunc) error {
if deleteTable != nil {
return deleteTable(ctx, tx, table)
}
if _, err := tx.ExecContext(ctx, "delete from "+store.QuoteIdent(table)); err != nil {
return fmt.Errorf("clear table %s: %w", table, err)
}
return nil
}
func insertRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error {
cols := make([]string, 0, len(row))
for col := range row {
@@ -391,8 +578,11 @@ type shardWriter struct {
nextShard int
rowsInShard int
files []string
fileManifests []FileManifest
currentRel string
file *os.File
counter *countingWriter
hasher hash.Hash
gz *gzip.Writer
}
@@ -415,8 +605,10 @@ func (w *shardWriter) open() error {
w.nextShard++
w.rowsInShard = 0
w.files = append(w.files, rel)
w.currentRel = rel
w.file = file
w.counter = &countingWriter{w: file}
w.hasher = sha256.New()
w.counter = &countingWriter{w: io.MultiWriter(file, w.hasher)}
w.gz = gzip.NewWriter(w.counter)
return nil
}
@@ -459,6 +651,17 @@ func (w *shardWriter) close() error {
if closeErr != nil {
return fmt.Errorf("close shard: %w", closeErr)
}
if w.currentRel != "" && w.counter != nil && w.hasher != nil {
w.fileManifests = append(w.fileManifests, FileManifest{
Path: w.currentRel,
Rows: w.rowsInShard,
Size: w.counter.n,
SHA256: hex.EncodeToString(w.hasher.Sum(nil)),
})
}
w.currentRel = ""
w.counter = nil
w.hasher = nil
return nil
}
@@ -481,3 +684,119 @@ func exportValue(value any) any {
return v
}
}
func planTableIncrement(previous, current TableManifest) TableImportPlan {
if !sameStrings(previous.Columns, current.Columns) {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: tableFileManifests(current), Reason: "columns changed"}
}
previousFiles := tableFileManifests(previous)
currentFiles := tableFileManifests(current)
if len(previousFiles) == 0 && len(currentFiles) == 0 {
return TableImportPlan{Table: current, Mode: TableImportSkip, Reason: "unchanged"}
}
if !allFilesHaveFingerprints(previousFiles) || !allFilesHaveFingerprints(currentFiles) {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "missing file fingerprints"}
}
if sameFileManifests(previousFiles, currentFiles) {
return TableImportPlan{Table: current, Mode: TableImportSkip, Reason: "unchanged"}
}
if len(currentFiles) < len(previousFiles) {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "files removed"}
}
for i := 0; i < len(previousFiles)-1; i++ {
if !sameFileManifest(previousFiles[i], currentFiles[i]) {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "non-tail file changed"}
}
}
changed := make([]FileManifest, 0, len(currentFiles)-len(previousFiles)+1)
if len(previousFiles) > 0 {
oldTail := previousFiles[len(previousFiles)-1]
newTail := currentFiles[len(previousFiles)-1]
if oldTail.Path != newTail.Path {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "tail path changed"}
}
if !sameFileManifest(oldTail, newTail) {
if newTail.Rows < oldTail.Rows {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "tail rows removed"}
}
changed = append(changed, newTail)
}
}
for i := len(previousFiles); i < len(currentFiles); i++ {
changed = append(changed, currentFiles[i])
}
if len(changed) == 0 {
return TableImportPlan{Table: current, Mode: TableImportSkip, Reason: "unchanged"}
}
return TableImportPlan{Table: current, Mode: TableImportFiles, Files: changed, Reason: "tail files changed"}
}
func tableFileManifests(table TableManifest) []FileManifest {
if len(table.FileManifests) > 0 {
out := make([]FileManifest, len(table.FileManifests))
copy(out, table.FileManifests)
return out
}
files := table.Files
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
files = []string{table.File}
}
out := make([]FileManifest, 0, len(files))
for _, file := range files {
out = append(out, FileManifest{Path: file})
}
return out
}
func allFilesHaveFingerprints(files []FileManifest) bool {
for _, file := range files {
if file.Path == "" || file.SHA256 == "" {
return false
}
}
return true
}
func sameFileManifests(a, b []FileManifest) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !sameFileManifest(a[i], b[i]) {
return false
}
}
return true
}
func sameFileManifest(a, b FileManifest) bool {
return a.Path == b.Path && a.Rows == b.Rows && a.Size == b.Size && a.SHA256 == b.SHA256
}
func fileManifestPaths(files []FileManifest) []string {
paths := make([]string, 0, len(files))
for _, file := range files {
paths = append(paths, file.Path)
}
return paths
}
func fileManifestRows(files []FileManifest) int {
rows := 0
for _, file := range files {
rows += file.Rows
}
return rows
}
func sameStrings(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
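
The planning rules in `planTableIncrement` can be read as "only appends to the tail are safe". The sketch below restates that decision on a trimmed-down, hypothetical `shard` type; it deliberately omits the column comparison and the missing-fingerprint fallback, and the string modes stand in for the `TableImport*` constants.

```go
package main

import "fmt"

// shard is a hypothetical, trimmed-down FileManifest.
type shard struct {
	Path   string
	Rows   int
	SHA256 string
}

// plan returns "skip", "replace", or "files" plus the shards to import.
func plan(prev, cur []shard) (string, []shard) {
	if len(cur) < len(prev) {
		return "replace", cur // shards disappeared
	}
	// Every shard before the previous tail must be byte-identical.
	for i := 0; i < len(prev)-1; i++ {
		if prev[i] != cur[i] {
			return "replace", cur
		}
	}
	var changed []shard
	if len(prev) > 0 {
		oldTail, newTail := prev[len(prev)-1], cur[len(prev)-1]
		// A renamed or shrunken tail cannot be an append.
		if oldTail.Path != newTail.Path || newTail.Rows < oldTail.Rows {
			return "replace", cur
		}
		if oldTail != newTail {
			changed = append(changed, newTail)
		}
	}
	// Shards past the previous tail are new and always imported.
	changed = append(changed, cur[len(prev):]...)
	if len(changed) == 0 {
		return "skip", nil
	}
	return "files", changed
}

func main() {
	prev := []shard{{"000000", 2, "h1"}, {"000001", 1, "h2"}}
	grown := []shard{{"000000", 2, "h1"}, {"000001", 3, "h3"}, {"000002", 1, "h4"}}
	mode, files := plan(prev, grown)
	fmt.Println(mode, len(files)) // tail shard plus the new shard

	rewritten := []shard{{"000000", 2, "hX"}, {"000001", 1, "h2"}}
	mode, _ = plan(prev, rewritten)
	fmt.Println(mode) // a non-tail shard changed
}
```

A changed tail shard is re-imported whole, so this only works when the row import is idempotent (for example an upsert keyed on the primary key), which is what the `things` test below appears to rely on.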

@@ -10,7 +10,7 @@ import (
"testing"
"time"
"github.com/vincentkoc/crawlkit/store"
"github.com/openclaw/crawlkit/store"
)
func TestExportImportTablesWithFilter(t *testing.T) {
@@ -95,6 +95,158 @@ func TestExportRotatesShards(t *testing.T) {
if len(manifest.Tables[0].Files) < 2 {
t.Fatalf("expected multiple shards, got %+v", manifest.Tables[0].Files)
}
if len(manifest.Tables[0].FileManifests) != len(manifest.Tables[0].Files) {
t.Fatalf("file manifests = %+v, files = %+v", manifest.Tables[0].FileManifests, manifest.Tables[0].Files)
}
for _, file := range manifest.Tables[0].FileManifests {
if file.Path == "" || file.Rows == 0 || file.Size == 0 || len(file.SHA256) != 64 {
t.Fatalf("bad file manifest = %+v", file)
}
}
}
func TestPlanIncrementalImportDetectsTailFiles(t *testing.T) {
previous := Manifest{
Version: 1,
Tables: []TableManifest{{
Name: "things",
Columns: []string{"id", "body"},
Rows: 2,
Files: []string{"tables/things/000000.jsonl.gz"},
FileManifests: []FileManifest{{
Path: "tables/things/000000.jsonl.gz",
Rows: 2,
Size: 100,
SHA256: "old",
}},
}},
}
current := Manifest{
Version: 1,
Tables: []TableManifest{{
Name: "things",
Columns: []string{"id", "body"},
Rows: 3,
Files: []string{"tables/things/000000.jsonl.gz"},
FileManifests: []FileManifest{{
Path: "tables/things/000000.jsonl.gz",
Rows: 3,
Size: 120,
SHA256: "new",
}},
}},
}
plan := PlanIncrementalImport(previous, current)
if plan.Full || len(plan.Tables) != 1 {
t.Fatalf("plan = %+v", plan)
}
table := plan.Tables[0]
if table.Mode != TableImportFiles || len(table.Files) != 1 || table.Files[0].SHA256 != "new" {
t.Fatalf("table plan = %+v", table)
}
}
func TestPlanIncrementalImportReplacesUnsafeChanges(t *testing.T) {
previous := Manifest{
Version: 1,
Tables: []TableManifest{{
Name: "things",
Columns: []string{"id", "body"},
Rows: 2,
Files: []string{"tables/things/000000.jsonl.gz"},
FileManifests: []FileManifest{{
Path: "tables/things/000000.jsonl.gz",
Rows: 2,
Size: 100,
SHA256: "old",
}},
}},
}
current := Manifest{
Version: 1,
Tables: []TableManifest{{
Name: "things",
Columns: []string{"id", "body"},
Rows: 1,
Files: []string{"tables/things/000000.jsonl.gz"},
FileManifests: []FileManifest{{
Path: "tables/things/000000.jsonl.gz",
Rows: 1,
Size: 100,
SHA256: "new",
}},
}},
}
plan := PlanIncrementalImport(previous, current)
if plan.Full || len(plan.Tables) != 1 || plan.Tables[0].Mode != TableImportReplace {
t.Fatalf("plan = %+v", plan)
}
}
func TestImportIncrementalImportsOnlyPlannedFiles(t *testing.T) {
ctx := context.Background()
src, err := store.Open(ctx, store.Options{
Path: filepath.Join(t.TempDir(), "src.db"),
Schema: `create table things(id text primary key, body text not null);`,
})
if err != nil {
t.Fatal(err)
}
defer src.Close()
mustExec(t, src.DB(), `insert into things(id, body) values('one', 'same')`)
mustExec(t, src.DB(), `insert into things(id, body) values('two', 'old')`)
root := t.TempDir()
previous, err := Export(ctx, ExportOptions{
DB: src.DB(),
RootDir: root,
Tables: []string{"things"},
})
if err != nil {
t.Fatal(err)
}
dst, err := store.Open(ctx, store.Options{
Path: filepath.Join(t.TempDir(), "dst.db"),
Schema: `create table things(id text primary key, body text not null);`,
})
if err != nil {
t.Fatal(err)
}
defer dst.Close()
if _, err := Import(ctx, ImportOptions{DB: dst.DB(), RootDir: root}); err != nil {
t.Fatal(err)
}
mustExec(t, dst.DB(), `insert into things(id, body) values('local', 'keep')`)
mustExec(t, src.DB(), `update things set body = 'new' where id = 'two'`)
mustExec(t, src.DB(), `insert into things(id, body) values('three', 'added')`)
current, err := Export(ctx, ExportOptions{
DB: src.DB(),
RootDir: root,
Tables: []string{"things"},
})
if err != nil {
t.Fatal(err)
}
_, plan, err := ImportIncremental(ctx, IncrementalImportOptions{
DB: dst.DB(),
RootDir: root,
Previous: previous,
Current: current,
})
if err != nil {
t.Fatal(err)
}
if len(plan.Tables) != 1 || plan.Tables[0].Mode != TableImportFiles {
t.Fatalf("plan = %+v", plan)
}
var got string
if err := dst.DB().QueryRowContext(ctx, `select group_concat(id || ':' || body, ',') from (select id, body from things order by id)`).Scan(&got); err != nil {
t.Fatal(err)
}
if got != "local:keep,one:same,three:added,two:new" {
t.Fatalf("things = %q", got)
}
}
func TestImportHooks(t *testing.T) {

142
vector/vector.go Normal file

@@ -0,0 +1,142 @@
package vector
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math"
"sort"
)
const DefaultRRFK = 60.0
type Scored[T any] struct {
Item T
Score float64
}
type RRFEntry[T any] struct {
Item T
Score float64
}
func EncodeFloat32(values []float32) ([]byte, error) {
buf := bytes.NewBuffer(make([]byte, 0, len(values)*4))
for _, value := range values {
if err := binary.Write(buf, binary.LittleEndian, value); err != nil {
return nil, fmt.Errorf("encode float32 vector: %w", err)
}
}
return buf.Bytes(), nil
}
func DecodeFloat32(blob []byte) ([]float32, error) {
if len(blob)%4 != 0 {
return nil, fmt.Errorf("float32 vector blob length %d is not a multiple of 4", len(blob))
}
out := make([]float32, len(blob)/4)
reader := bytes.NewReader(blob)
for i := range out {
if err := binary.Read(reader, binary.LittleEndian, &out[i]); err != nil {
return nil, fmt.Errorf("decode float32 vector: %w", err)
}
}
return out, nil
}
func ValidateDimensions(values []float32, dimensions int) error {
if dimensions <= 0 {
return errors.New("dimensions must be positive")
}
if len(values) != dimensions {
return fmt.Errorf("dimensions mismatch: got %d want %d", len(values), dimensions)
}
return nil
}
func Norm(values []float32) float64 {
var sum float64
for _, value := range values {
sum += float64(value) * float64(value)
}
return math.Sqrt(sum)
}
func CosineSimilarity(query []float32, queryNorm float64, candidate []float32) (float64, error) {
if len(candidate) != len(query) {
return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(candidate), len(query))
}
if queryNorm == 0 {
return 0, errors.New("query vector is zero")
}
candidateNorm := Norm(candidate)
if candidateNorm == 0 {
return 0, errors.New("candidate vector is zero")
}
var dot float64
for i := range query {
dot += float64(query[i]) * float64(candidate[i])
}
return dot / (queryNorm * candidateNorm), nil
}
func TopK[T any](items []Scored[T], limit int, tieLess func(left, right T) bool) []Scored[T] {
if limit <= 0 || len(items) == 0 {
return nil
}
sorted := append([]Scored[T](nil), items...)
sort.SliceStable(sorted, func(i, j int) bool {
if sorted[i].Score != sorted[j].Score {
return sorted[i].Score > sorted[j].Score
}
if tieLess == nil {
return false
}
return tieLess(sorted[i].Item, sorted[j].Item)
})
if len(sorted) > limit {
sorted = sorted[:limit]
}
return sorted
}
func ReciprocalRankFusion[T any](rankings [][]T, ids []func(T) string, weights []float64, k float64) []RRFEntry[T] {
if k <= 0 {
k = DefaultRRFK
}
entries := map[string]*RRFEntry[T]{}
for rankingIndex, ranking := range rankings {
weight := 1.0
if rankingIndex < len(weights) && weights[rankingIndex] != 0 {
weight = weights[rankingIndex]
}
var idFn func(T) string
if rankingIndex < len(ids) {
idFn = ids[rankingIndex]
}
for index, item := range ranking {
if idFn == nil {
continue
}
id := idFn(item)
if id == "" {
continue
}
entry := entries[id]
if entry == nil {
entry = &RRFEntry[T]{Item: item}
entries[id] = entry
}
entry.Score += weight / (k + float64(index+1))
}
}
out := make([]RRFEntry[T], 0, len(entries))
for _, entry := range entries {
out = append(out, *entry)
}
sort.SliceStable(out, func(i, j int) bool {
return out[i].Score > out[j].Score
})
return out
}
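
`ReciprocalRankFusion` scores each item as the sum of `weight / (k + rank)` over the rankings it appears in. Below is a reduced, string-only sketch of the same formula with unit weights; `rrf` and `fused` are hypothetical names. The ID tie-break is an addition of this sketch: the function above collects entries by ranging over a map, so equal-score entries may come back in varying order.

```go
package main

import (
	"fmt"
	"sort"
)

// fused is a hypothetical result type for this sketch.
type fused struct {
	ID    string
	Score float64
}

// rrf fuses rankings with unit weights: score(id) = sum of 1/(k + rank),
// where rank is 1-based within each ranking.
func rrf(rankings [][]string, k float64) []fused {
	scores := map[string]float64{}
	for _, ranking := range rankings {
		for index, id := range ranking {
			scores[id] += 1.0 / (k + float64(index+1))
		}
	}
	out := make([]fused, 0, len(scores))
	for id, score := range scores {
		out = append(out, fused{ID: id, Score: score})
	}
	// Sort by score descending; break ties by ID so output is deterministic.
	sort.Slice(out, func(i, j int) bool {
		if out[i].Score != out[j].Score {
			return out[i].Score > out[j].Score
		}
		return out[i].ID < out[j].ID
	})
	return out
}

func main() {
	// b appears in both lists: 1/62 + 1/61. a scores 1/61, c scores 1/62.
	for _, entry := range rrf([][]string{{"a", "b"}, {"b", "c"}}, 60) {
		fmt.Printf("%s %.5f\n", entry.ID, entry.Score)
	}
}
```

With `k = 60` the fused order is b, a, c, matching the expectation in `TestReciprocalRankFusion` that `b` ranks first.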

137
vector/vector_test.go Normal file

@@ -0,0 +1,137 @@
package vector
import (
"math"
"reflect"
"strings"
"testing"
)
func TestFloat32EncodingRoundTrip(t *testing.T) {
blob, err := EncodeFloat32([]float32{1, -2.5, 3.25})
require.NoError(t, err)
require.Len(t, blob, 12)
values, err := DecodeFloat32(blob)
require.NoError(t, err)
require.Equal(t, []float32{1, -2.5, 3.25}, values)
_, err = DecodeFloat32([]byte{1, 2, 3})
require.ErrorContains(t, err, "not a multiple of 4")
}
func TestCosineSimilarityAndDimensions(t *testing.T) {
require.NoError(t, ValidateDimensions([]float32{1, 2}, 2))
require.ErrorContains(t, ValidateDimensions([]float32{1}, 2), "dimensions mismatch")
require.ErrorContains(t, ValidateDimensions([]float32{1}, 0), "positive")
query := []float32{1, 0}
score, err := CosineSimilarity(query, Norm(query), []float32{0.5, 0})
require.NoError(t, err)
require.InDelta(t, 1, score, 0.0001)
_, err = CosineSimilarity(query, 0, []float32{1, 0})
require.ErrorContains(t, err, "query vector is zero")
_, err = CosineSimilarity(query, Norm(query), []float32{0, 0})
require.ErrorContains(t, err, "candidate vector is zero")
_, err = CosineSimilarity(query, Norm(query), []float32{1})
require.ErrorContains(t, err, "dimensions mismatch")
require.Equal(t, math.Sqrt(5), Norm([]float32{1, 2}))
}
func TestTopK(t *testing.T) {
items := []Scored[string]{
{Item: "c", Score: 0.3},
{Item: "a", Score: 0.5},
{Item: "b", Score: 0.5},
}
top := TopK(items, 2, func(left, right string) bool { return left < right })
require.Equal(t, []Scored[string]{{Item: "a", Score: 0.5}, {Item: "b", Score: 0.5}}, top)
require.Nil(t, TopK(items, 0, nil))
}
func TestReciprocalRankFusion(t *testing.T) {
rankings := [][]string{
{"a", "b"},
{"b", "c"},
}
ids := []func(string) string{
func(value string) string { return value },
func(value string) string { return value },
}
results := ReciprocalRankFusion(rankings, ids, []float64{1, 1}, 60)
require.Len(t, results, 3)
require.Equal(t, "b", results[0].Item)
require.Greater(t, results[0].Score, results[1].Score)
}
type requireAPI struct{}
var require requireAPI
func (requireAPI) NoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func (requireAPI) Equal(t *testing.T, want, got any) {
t.Helper()
if !reflect.DeepEqual(want, got) {
t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got)
}
}
func (requireAPI) Len(t *testing.T, value any, want int) {
t.Helper()
got := reflect.ValueOf(value).Len()
if got != want {
t.Fatalf("len mismatch: got %d want %d", got, want)
}
}
func (requireAPI) Nil(t *testing.T, value any) {
t.Helper()
if !isNil(value) {
t.Fatalf("expected nil, got %#v", value)
}
}
func (requireAPI) Greater(t *testing.T, left, right float64) {
t.Helper()
if left <= right {
t.Fatalf("expected %v > %v", left, right)
}
}
func (requireAPI) InDelta(t *testing.T, want, got, delta float64) {
t.Helper()
diff := math.Abs(want - got)
if diff > delta {
t.Fatalf("not within delta: want %v got %v delta %v", want, got, delta)
}
}
func (requireAPI) ErrorContains(t *testing.T, err error, needle string) {
t.Helper()
if err == nil {
t.Fatalf("expected error containing %q, got nil", needle)
}
if !strings.Contains(err.Error(), needle) {
t.Fatalf("expected error containing %q, got %q", needle, err.Error())
}
}
func isNil(value any) bool {
if value == nil {
return true
}
reflected := reflect.ValueOf(value)
switch reflected.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
return reflected.IsNil()
default:
return false
}
}
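
`TestFloat32EncodingRoundTrip` expects three values to occupy 12 bytes, that is, 4 little-endian bytes per `float32`. The sketch below shows the same byte layout using `math.Float32bits` with `binary.LittleEndian` directly. This is an alternative formulation for illustration, not the `binary.Write` approach used in `vector.go`, and `encode`/`decode` are hypothetical names.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// encode packs each float32 as its IEEE 754 bits in little-endian order.
func encode(values []float32) []byte {
	out := make([]byte, 4*len(values))
	for i, v := range values {
		binary.LittleEndian.PutUint32(out[4*i:], math.Float32bits(v))
	}
	return out
}

// decode reverses encode and rejects blobs that are not whole 4-byte groups.
func decode(blob []byte) ([]float32, error) {
	if len(blob)%4 != 0 {
		return nil, fmt.Errorf("blob length %d is not a multiple of 4", len(blob))
	}
	out := make([]float32, len(blob)/4)
	for i := range out {
		out[i] = math.Float32frombits(binary.LittleEndian.Uint32(blob[4*i:]))
	}
	return out, nil
}

func main() {
	blob := encode([]float32{1, -2.5, 3.25})
	values, err := decode(blob)
	fmt.Println(len(blob), values, err) // prints "12 [1 -2.5 3.25] <nil>"
}
```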