Compare commits

...

20 Commits

Author SHA1 Message Date
Peter Steinberger
55196d74e0
feat: add shared encrypted backup helpers
Some checks failed
CI / test (push) Has been cancelled
CodeQL / analyze (push) Has been cancelled
Security Gate: Secret Scanning / Scan for Verified Secrets (push) Has been cancelled
2026-05-08 16:41:44 +01:00
Peter Steinberger
1cc2c66283
feat: add shared embedding and vector helpers 2026-05-08 09:56:40 +01:00
Peter Steinberger
7fbca35339
docs: prepare v0.4.2 release notes 2026-05-08 07:58:41 +01:00
Peter Steinberger
4d976d782b
feat(snapshot): add incremental shard import planning 2026-05-08 07:10:12 +01:00
Vincent Koc
6a2ae79aa6
docs: note dependency updates
Some checks are pending
CI / test (push) Waiting to run
CodeQL / analyze (push) Waiting to run
Security Gate: Secret Scanning / Scan for Verified Secrets (push) Waiting to run
2026-05-07 02:52:15 -07:00
dependabot[bot]
bbc8e09c07
build(deps): bump github.com/mattn/go-isatty from 0.0.20 to 0.0.22 (#4)
Bumps [github.com/mattn/go-isatty](https://github.com/mattn/go-isatty) from 0.0.20 to 0.0.22.
- [Commits](https://github.com/mattn/go-isatty/compare/v0.0.20...v0.0.22)

---
updated-dependencies:
- dependency-name: github.com/mattn/go-isatty
  dependency-version: 0.0.22
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 02:45:10 -07:00
dependabot[bot]
7fc82a3e08
build(deps): bump github.com/charmbracelet/x/ansi from 0.11.6 to 0.11.7 (#3)
Bumps [github.com/charmbracelet/x/ansi](https://github.com/charmbracelet/x) from 0.11.6 to 0.11.7.
- [Commits](https://github.com/charmbracelet/x/compare/ansi/v0.11.6...ansi/v0.11.7)

---
updated-dependencies:
- dependency-name: github.com/charmbracelet/x/ansi
  dependency-version: 0.11.7
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 02:40:44 -07:00
dependabot[bot]
9529b8547b
build(deps): bump github.com/pelletier/go-toml/v2 from 2.3.0 to 2.3.1 (#2)
Bumps [github.com/pelletier/go-toml/v2](https://github.com/pelletier/go-toml) from 2.3.0 to 2.3.1.
- [Release notes](https://github.com/pelletier/go-toml/releases)
- [Commits](https://github.com/pelletier/go-toml/compare/v2.3.0...v2.3.1)

---
updated-dependencies:
- dependency-name: github.com/pelletier/go-toml/v2
  dependency-version: 2.3.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 02:40:38 -07:00
Vincent Koc
ed46d22108
docs: prepare v0.4.1 release notes
Some checks are pending
CI / test (push) Waiting to run
CodeQL / analyze (push) Waiting to run
Security Gate: Secret Scanning / Scan for Verified Secrets (push) Waiting to run
2026-05-06 14:05:30 -07:00
Vincent Koc
b5fe53da65
Merge pull request #5 from openclaw/ci-security-baseline
chore(ci): add crawl security baseline
2026-05-06 01:55:23 -07:00
Vincent Koc
ae16f4d2a0
chore(security): add verified secret scanning 2026-05-06 01:37:03 -07:00
Vincent Koc
04bc03275c
docs: split v0.4.0 changelog 2026-05-06 01:26:43 -07:00
Vincent Koc
aab645c938
chore: add funding config 2026-05-06 01:23:00 -07:00
Vincent Koc
c785cf6253
chore(ci): add stale issue automation 2026-05-06 00:30:11 -07:00
Vincent Koc
17f8be07db
chore(security): add protected automation owners 2026-05-06 00:30:10 -07:00
Vincent Koc
57630c2f95
docs: document crawlkit adoption 2026-05-05 19:16:51 -07:00
Vincent Koc
bb51c9ea12
feat(state): add legacy sync adapters 2026-05-05 17:21:30 -07:00
Vincent Koc
8cda2498b2
feat(mirror): add safer share repo helpers 2026-05-05 17:18:53 -07:00
Vincent Koc
59c0033fc7
docs: define crawlkit app boundary 2026-05-05 17:16:57 -07:00
Vincent Koc
43454a8af2
docs: add crawlkit agent guidance 2026-05-05 17:09:49 -07:00
26 changed files with 3266 additions and 66 deletions

10
.github/CODEOWNERS vendored Normal file
View File

@ -0,0 +1,10 @@
# Protect ownership and automation rules.
/.github/CODEOWNERS @openclaw/openclaw-secops
/.github/dependabot.yml @openclaw/openclaw-secops
/.github/workflows/ @openclaw/openclaw-secops
/AGENTS.md @openclaw/openclaw-secops
# Release and package integrity surfaces.
/docs/publishing.md @openclaw/openclaw-secops
/go.mod @openclaw/openclaw-secops
/go.sum @openclaw/openclaw-secops

1
.github/FUNDING.yml vendored Normal file
View File

@ -0,0 +1 @@
github: vincentkoc

63
.github/workflows/secret-scan.yml vendored Normal file
View File

@ -0,0 +1,63 @@
name: "Security Gate: Secret Scanning"
on:
push:
branches: ["**"]
pull_request:
branches: [main, master]
permissions: {}
jobs:
trufflehog:
name: Scan for Verified Secrets
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Checkout code
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Resolve scan range
id: scan_range
env:
EVENT_NAME: ${{ github.event_name }}
PR_BASE_SHA: ${{ github.event.pull_request.base.sha }}
PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
PUSH_BASE_SHA: ${{ github.event.before }}
PUSH_HEAD_SHA: ${{ github.sha }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
zero_sha="0000000000000000000000000000000000000000"
if [[ "$EVENT_NAME" == "pull_request" ]]; then
base="$PR_BASE_SHA"
head="$PR_HEAD_SHA"
else
base="$PUSH_BASE_SHA"
head="$PUSH_HEAD_SHA"
if [[ -z "$base" || "$base" == "$zero_sha" ]]; then
base="origin/$DEFAULT_BRANCH"
fi
fi
echo "base=$base" >> "$GITHUB_OUTPUT"
echo "head=$head" >> "$GITHUB_OUTPUT"
- name: TruffleHog OSS
id: trufflehog
uses: trufflesecurity/trufflehog@v3.95.2
with:
path: ./
base: ${{ steps.scan_range.outputs.base }}
head: ${{ steps.scan_range.outputs.head }}
extra_args: --only-verified --debug
- name: Notify on failure
if: steps.trufflehog.outcome == 'failure'
run: |
echo "::error::Verified secrets found. Rotate the credential before merging."
exit 1

86
.github/workflows/stale.yml vendored Normal file
View File

@ -0,0 +1,86 @@
name: Stale
on:
schedule:
- cron: "17 4 * * *"
workflow_dispatch:
permissions: {}
jobs:
stale:
permissions:
issues: write
pull-requests: write
runs-on: ubuntu-latest
steps:
- name: Mark stale unassigned issues and pull requests
uses: actions/stale@v10
with:
days-before-issue-stale: 14
days-before-issue-close: 7
days-before-pr-stale: 14
days-before-pr-close: 7
stale-issue-label: stale
stale-pr-label: stale
exempt-issue-labels: enhancement,maintainer,pinned,security,no-stale
exempt-pr-labels: maintainer,no-stale
operations-per-run: 1000
ascending: true
exempt-all-assignees: true
remove-stale-when-updated: true
stale-issue-message: |
This issue has been automatically marked as stale due to inactivity.
Please add updated crawlkit details or it will be closed.
stale-pr-message: |
This pull request has been automatically marked as stale due to inactivity.
Please update it or it will be closed.
close-issue-message: |
Closing due to inactivity.
If this still affects crawlkit, open a new issue with current reproduction details.
close-issue-reason: not_planned
close-pr-message: |
Closing due to inactivity.
If this PR should be revived, reopen it with current context and validation.
- name: Mark stale assigned issues
uses: actions/stale@v10
with:
days-before-issue-stale: 30
days-before-issue-close: 10
days-before-pr-stale: -1
days-before-pr-close: -1
stale-issue-label: stale
exempt-issue-labels: enhancement,maintainer,pinned,security,no-stale
operations-per-run: 1000
ascending: true
include-only-assigned: true
remove-stale-when-updated: true
stale-issue-message: |
This assigned issue has been automatically marked as stale after 30 days of inactivity.
Please add an update or it will be closed.
close-issue-message: |
Closing due to inactivity.
If this still affects crawlkit, reopen or file a new issue with current evidence.
close-issue-reason: not_planned
- name: Mark stale assigned pull requests
uses: actions/stale@v10
with:
days-before-issue-stale: -1
days-before-issue-close: -1
days-before-pr-stale: 27
days-before-pr-close: 7
stale-pr-label: stale
exempt-pr-labels: maintainer,no-stale
operations-per-run: 1000
ascending: true
include-only-assigned: true
ignore-pr-updates: true
remove-stale-when-updated: true
stale-pr-message: |
This assigned pull request has been automatically marked as stale after being open for 27 days.
Please add an update or it will be closed.
close-pr-message: |
Closing due to inactivity.
If this PR should be revived, reopen it with current context and validation.

82
AGENTS.md Normal file
View File

@ -0,0 +1,82 @@
# AGENTS.md
## Purpose
`crawlkit` is the shared Go library for the crawl app family. It owns reusable
local archive mechanics: config paths, SQLite helpers, snapshot packing,
git-backed mirrors, sync state, CLI output helpers, terminal archive browsing,
progress logs, and safe local cache reads.
It is not a provider crawler. Keep Slack, Discord, Notion, GitHub, and other
provider-specific behavior in the downstream apps unless the abstraction is
clearly reusable across at least two apps.
Use `docs/boundary.md` as the working ownership map when deciding whether a
feature belongs in `crawlkit` or a downstream crawl app.
## Development Rules
- Keep public package nouns stable and small: `config`, `store`, `snapshot`,
`mirror`, `state`, `output`, `progress`, `tui`, `cache`, and `control`.
- Prefer additive APIs. If an API must change, preserve downstream
compatibility or update all crawl app branches in the same work cycle.
- Do not add app-specific database schema, auth, API, or cache parsing logic to
this library.
- Do not touch live app stores during tests. Use temp dirs and temp SQLite
files only.
- Use `GOWORK=off` for release and downstream-compatibility checks so local
workspaces do not hide missing tagged APIs.
## Validation
Run before handoff:
```bash
GOWORK=off go mod tidy
git diff --exit-code -- go.mod go.sum
GOWORK=off go vet ./...
GOWORK=off go test -count=1 ./...
```
For release readiness, also verify the public module tag:
```bash
GOPROXY=https://proxy.golang.org GONOSUMDB= go list -m github.com/openclaw/crawlkit@v0.5.0
```
## Downstream Compatibility
When changing exported APIs or TUI behavior, smoke the app branches with temp
home/config/cache directories:
```bash
GOWORK=off go test ./...
<app> --help
<app> --version
<app> metadata --json
<app> status --json
<app> tui --json
```
Use read-only or temp data. Never mutate `~/.gitcrawl`, `~/.slacrawl`,
`~/.discrawl`, `~/.notcrawl`, or equivalent live archives.
## TUI Standard
The shared `tui` package should track the best `gitcrawl` terminal browser
patterns: pane-aware focus, sortable table headers, mouse selection,
right-click action menus, responsive three-pane/split/stacked layouts, compact
chat/document detail rendering, clean footer status, and reliable terminal
shutdown on signals.
If a downstream app needs TUI polish that is generic, backport it here first and
then consume it from the app.
## Release Model
Go libraries are released by signed semver git tags. There is no npm, PyPI, or
Homebrew publish step for `crawlkit`.
Use patch tags for narrow fixes and minor tags for broader shared crawler or TUI
infrastructure. After tagging, prime/verify the Go proxy and then update
downstream apps to the published tag.

View File

@ -2,7 +2,41 @@
## Unreleased
## v0.5.1 - 2026-05-08
- Add reusable `backup` helpers for age identities, encrypted JSONL/Gzip shards,
manifests, recipient tracking, shard hash verification, and stale shard
cleanup.
- Add reusable `embed` providers for OpenAI, OpenAI-compatible endpoints,
Ollama, and llama.cpp, including probe diagnostics and rate-limit errors.
- Add reusable `vector` helpers for float32 blobs, dimension validation,
cosine similarity, top-k sorting, and reciprocal-rank fusion.
## v0.4.2 - 2026-05-08
- Add snapshot file fingerprints and an incremental import planner/executor so downstream apps can import changed JSONL/Gzip shards without deleting every table.
- Move the module path to `github.com/openclaw/crawlkit`.
- Bump routine Go module dependencies.
## v0.4.1 - 2026-05-06
- Add GitHub Sponsors funding metadata.
- Add crawlkit agent guidance for shared-library maintenance.
- Document downstream adoption status for `gitcrawl`, `discrawl`, `slacrawl`,
and `notcrawl`, including the app-owned provider/auth/privacy boundary.
- Document the `crawlkit` versus crawl-app boundary for embeddings, search,
inference, sync state, snapshots, SQLite, and git mirrors.
- Add safer `mirror` helpers for origin updates, existing-origin pulls,
path-scoped commits, and portable SQLite sidecar cleanup.
- Add `state.ScopedStore` and `state.CursorStore` adapters for legacy sync
state table shapes used by downstream apps.
## v0.4.0 - 2026-05-05
- Initial `crawlkit` module scaffold.
- Add the `control` package to the public package inventory for app metadata,
command manifests, status payloads, and database inventory.
- Add `tui`, a shared Bubble Tea terminal archive browser used by the crawl apps for consistent `tui` command behavior.
- Improve `tui` rows with compact column rendering, pane-specific scrolling, and full-height pane borders.
- Tune `tui` pane colors and mouse-wheel buffering to better match the `gitcrawl` terminal browser feel.

View File

@ -5,29 +5,43 @@ Shared Go infrastructure for local-first crawler archives.
`crawlkit` is not a universal Slack, Discord, Notion, or GitHub crawler. It is
the reusable foundation beneath those tools: SQLite hygiene, TOML config
defaults, portable JSONL/Gzip packing, git-backed snapshot sharing, sync state,
CLI output helpers, a shared terminal explorer, and safe desktop-cache snapshot
utilities.
CLI output helpers, control/status metadata, a shared terminal explorer, and
safe desktop-cache snapshot utilities.
## Install
```bash
go get github.com/vincentkoc/crawlkit@latest
go get github.com/openclaw/crawlkit@latest
```
Go packages are published by tagging this repository. There is no separate
package registry step. See `docs/publishing.md` for the release commands.
See `docs/boundary.md` for the crawlkit-versus-app ownership boundary.
## Packages
- `config`: standard TOML config paths, runtime dirs, and token diagnostics.
- `store`: SQLite open/read-only/transaction/query helpers.
- `snapshot`: `manifest.json` plus JSONL/Gzip table snapshot export and import.
- `snapshot`: `manifest.json` plus JSONL/Gzip table snapshot export, file fingerprints, full import, and planned incremental shard import.
- `backup`: age-encrypted JSONL/Gzip shards, backup manifests, recipient/identity helpers, and shard restore verification.
- `mirror`: clone/init/pull/commit/push helpers for private snapshot repos.
- `state`: generic crawler cursor and freshness records.
- `embed`: reusable OpenAI-compatible, Ollama, and llama.cpp embedding providers plus local probe diagnostics.
- `vector`: float32 vector encoding, dimension validation, cosine scoring, top-k helpers, and reciprocal-rank fusion.
- `output`: text/json/log output helpers.
- `control`: crawl app metadata, command manifests, status payloads, and
database inventory for launchers and automation.
- `tui`: shared terminal archive explorer with gitcrawl-style responsive panes, entity/member/detail lanes, compact sortable headers, mouse selection, floating right-click actions, sorting/filtering, and local/remote source status.
- `cache`: safe read-only local cache snapshot helpers.
## Downstream apps
- `gitcrawl` and `discrawl` consume `crawlkit` on `main`.
- `slacrawl` and `notcrawl` consume `crawlkit` on their `feat/use-crawlkit`
integration branches until those app rewires are merged.
- The apps keep provider schemas, auth, desktop/API parsing, privacy filters,
and user-facing CLI contracts. `crawlkit` owns only the reusable mechanics.
## Safety
Library tests use temporary directories. They do not touch app runtime stores

339
backup/backup.go Normal file
View File

@ -0,0 +1,339 @@
package backup
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"path"
"path/filepath"
"reflect"
"sort"
"strings"
"time"
)
const FormatVersion = 1
type Config struct {
Repo string
Identity string
Recipients []string
}
type Manifest struct {
Format int `json:"format"`
Encrypted bool `json:"encrypted"`
Exported time.Time `json:"exported"`
Recipients []string `json:"recipients,omitempty"`
Counts map[string]int `json:"counts"`
Shards []ShardEntry `json:"shards"`
}
type Shard struct {
Table string
Path string
Rows any
}
type ShardEntry struct {
Table string `json:"table"`
Path string `json:"path"`
Rows int `json:"rows"`
SHA256 string `json:"sha256"`
Bytes int64 `json:"bytes"`
}
type DecodedShard struct {
Entry ShardEntry
Plaintext []byte
}
func WriteSnapshot(ctx context.Context, cfg Config, shards []Shard, old Manifest) (Manifest, error) {
_ = ctx
recipients := normalizedStrings(cfg.Recipients)
reuseEncrypted := sameStrings(old.Recipients, recipients)
manifest := Manifest{
Format: FormatVersion,
Encrypted: true,
Exported: time.Now().UTC(),
Recipients: recipients,
Counts: map[string]int{},
}
for _, shard := range shards {
plaintext, rows, err := EncodeJSONL(shard.Rows)
if err != nil {
return Manifest{}, fmt.Errorf("encode %s: %w", shard.Table, err)
}
entry, err := WriteShard(cfg, old, shard.Table, shard.Path, plaintext, rows, reuseEncrypted)
if err != nil {
return Manifest{}, err
}
manifest.Counts[shard.Table] += rows
manifest.Shards = append(manifest.Shards, entry)
}
sort.Slice(manifest.Shards, func(i, j int) bool { return manifest.Shards[i].Path < manifest.Shards[j].Path })
if EquivalentManifest(old, manifest) {
return old, nil
}
if err := RemoveStaleShards(cfg.Repo, manifest.Shards); err != nil {
return Manifest{}, err
}
if err := WriteManifest(cfg.Repo, manifest); err != nil {
return Manifest{}, err
}
return manifest, nil
}
func ReadSnapshot(cfg Config, manifest Manifest) ([]DecodedShard, error) {
if manifest.Format != FormatVersion {
return nil, fmt.Errorf("unsupported backup format %d", manifest.Format)
}
var out []DecodedShard
for _, shard := range manifest.Shards {
plaintext, err := DecryptShardFile(cfg, shard)
if err != nil {
return nil, err
}
if got := SHA256Hex(plaintext); got != shard.SHA256 {
return nil, fmt.Errorf("backup shard hash mismatch for %s", shard.Path)
}
out = append(out, DecodedShard{Entry: shard, Plaintext: plaintext})
}
return out, nil
}
func WriteShard(cfg Config, old Manifest, table, rel string, plaintext []byte, rows int, reuseEncrypted bool) (ShardEntry, error) {
hash := SHA256Hex(plaintext)
target, err := ResolveShardPath(cfg.Repo, rel)
if err != nil {
return ShardEntry{}, err
}
if oldEntry, ok := old.Entry(rel); reuseEncrypted && ok && oldEntry.SHA256 == hash {
if info, err := os.Stat(target); err == nil {
oldEntry.Bytes = info.Size()
return oldEntry, nil
}
}
encrypted, _, err := EncryptShard(plaintext, cfg.Recipients)
if err != nil {
return ShardEntry{}, err
}
if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
return ShardEntry{}, err
}
if err := os.WriteFile(target, encrypted, 0o600); err != nil {
return ShardEntry{}, err
}
return ShardEntry{Table: table, Path: rel, Rows: rows, SHA256: hash, Bytes: int64(len(encrypted))}, nil
}
func DecryptShardFile(cfg Config, shard ShardEntry) ([]byte, error) {
target, err := ResolveShardPath(cfg.Repo, shard.Path)
if err != nil {
return nil, err
}
ciphertext, err := os.ReadFile(target) // #nosec G304 -- ResolveShardPath confines manifest-controlled paths below data/.
if err != nil {
return nil, err
}
return DecryptShard(ciphertext, cfg.Identity)
}
func ResolveShardPath(repo, rel string) (string, error) {
clean := path.Clean(strings.TrimSpace(rel))
if clean == "." || clean == ".." || strings.HasPrefix(clean, "../") || path.IsAbs(clean) {
return "", fmt.Errorf("backup shard path escapes backup root: %s", rel)
}
if !strings.HasPrefix(clean, "data/") || !strings.HasSuffix(clean, ".age") {
return "", fmt.Errorf("invalid backup shard path: %s", rel)
}
full := filepath.Join(repo, filepath.FromSlash(clean))
root := filepath.Clean(filepath.Join(repo, "data"))
parent := filepath.Clean(filepath.Dir(full))
if parent != root && !strings.HasPrefix(parent, root+string(filepath.Separator)) {
return "", fmt.Errorf("backup shard path escapes backup root: %s", rel)
}
return full, nil
}
func EncodeJSONL(rows any) ([]byte, int, error) {
value := reflect.ValueOf(rows)
if value.Kind() != reflect.Slice {
return nil, 0, fmt.Errorf("unsupported JSONL rows %T", rows)
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
for i := 0; i < value.Len(); i++ {
if err := enc.Encode(value.Index(i).Interface()); err != nil {
return nil, 0, err
}
}
return buf.Bytes(), value.Len(), nil
}
func DecodeJSONL[T any](plaintext []byte, out *[]T) error {
scanner := bufio.NewScanner(bytes.NewReader(plaintext))
scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
for scanner.Scan() {
var value T
if err := json.Unmarshal(scanner.Bytes(), &value); err != nil {
return err
}
*out = append(*out, value)
}
return scanner.Err()
}
func ReadManifest(repo string) (Manifest, error) {
data, err := os.ReadFile(filepath.Join(repo, "manifest.json")) // #nosec G304 -- repo is configured by caller.
if err != nil {
return Manifest{}, err
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return Manifest{}, err
}
return manifest, nil
}
func WriteManifest(repo string, manifest Manifest) error {
data, err := json.MarshalIndent(manifest, "", " ")
if err != nil {
return err
}
data = append(data, '\n')
return os.WriteFile(filepath.Join(repo, "manifest.json"), data, 0o600)
}
func (m Manifest) Entry(path string) (ShardEntry, bool) {
for _, shard := range m.Shards {
if shard.Path == path {
return shard, true
}
}
return ShardEntry{}, false
}
func EquivalentManifest(a, b Manifest) bool {
if a.Format != b.Format || a.Encrypted != b.Encrypted || !sameStrings(a.Recipients, b.Recipients) || !sameCounts(a.Counts, b.Counts) || len(a.Shards) != len(b.Shards) {
return false
}
for i := range a.Shards {
left, right := a.Shards[i], b.Shards[i]
left.Bytes, right.Bytes = 0, 0
if left != right {
return false
}
}
return true
}
func RemoveStaleShards(repo string, shards []ShardEntry) error {
keep := map[string]struct{}{}
for _, shard := range shards {
keep[filepath.Clean(filepath.Join(repo, filepath.FromSlash(shard.Path)))] = struct{}{}
}
root := filepath.Join(repo, "data")
if _, err := os.Stat(root); os.IsNotExist(err) {
return nil
}
var stale []string
if err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
if err != nil || d == nil || d.IsDir() {
return err
}
if !strings.HasSuffix(path, ".age") {
return nil
}
clean := filepath.Clean(path)
if _, ok := keep[clean]; ok {
return nil
}
stale = append(stale, clean)
return nil
}); err != nil {
return err
}
for _, path := range stale {
rel, err := filepath.Rel(root, path)
if err != nil || rel == "." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
return fmt.Errorf("stale shard path escapes backup root: %s", path)
}
if err := os.Remove(path); err != nil {
return err
}
}
return nil
}
func EncryptShard(plaintext []byte, recipients []string) ([]byte, string, error) {
return encryptShard(plaintext, recipients)
}
func DecryptShard(ciphertext []byte, identityPath string) ([]byte, error) {
return decryptShard(ciphertext, identityPath)
}
func SHA256Hex(data []byte) string {
return sha256Hex(data)
}
func normalizedStrings(values []string) []string {
seen := map[string]struct{}{}
out := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
out = append(out, value)
}
sort.Strings(out)
return out
}
func sameStrings(a, b []string) bool {
a, b = normalizedStrings(a), normalizedStrings(b)
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func sameCounts(a, b map[string]int) bool {
if len(a) != len(b) {
return false
}
for key, left := range a {
if b[key] != left {
return false
}
}
return true
}
func expandHome(p string) string {
if p == "~" {
if home, err := os.UserHomeDir(); err == nil {
return home
}
}
if after, ok := strings.CutPrefix(p, "~/"); ok {
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, after)
}
}
return p
}

54
backup/backup_test.go Normal file
View File

@ -0,0 +1,54 @@
package backup
import (
"context"
"os"
"path/filepath"
"testing"
)
type row struct {
ID string `json:"id"`
Body string `json:"body"`
}
func TestWriteReadEncryptedSnapshot(t *testing.T) {
dir := t.TempDir()
identity := filepath.Join(dir, "age.key")
recipient, err := EnsureIdentity(identity)
if err != nil {
t.Fatal(err)
}
cfg := Config{Repo: filepath.Join(dir, "repo"), Identity: identity, Recipients: []string{recipient}}
if err := os.MkdirAll(cfg.Repo, 0o700); err != nil {
t.Fatal(err)
}
manifest, err := WriteSnapshot(context.Background(), cfg, []Shard{
{Table: "messages", Path: "data/messages/2026/05.jsonl.gz.age", Rows: []row{{ID: "1", Body: "hello"}}},
}, Manifest{})
if err != nil {
t.Fatal(err)
}
if manifest.Counts["messages"] != 1 || len(manifest.Shards) != 1 {
t.Fatalf("unexpected manifest: %+v", manifest)
}
decoded, err := ReadSnapshot(cfg, manifest)
if err != nil {
t.Fatal(err)
}
var rows []row
if err := DecodeJSONL(decoded[0].Plaintext, &rows); err != nil {
t.Fatal(err)
}
if len(rows) != 1 || rows[0].Body != "hello" {
t.Fatalf("unexpected rows: %+v", rows)
}
}
func TestResolveShardPathRejectsEscapes(t *testing.T) {
for _, rel := range []string{"../x.age", "data/../x.age", "data/x.txt", "/data/x.age"} {
if _, err := ResolveShardPath(t.TempDir(), rel); err == nil {
t.Fatalf("expected error for %q", rel)
}
}
}

141
backup/crypto.go Normal file
View File

@ -0,0 +1,141 @@
package backup
import (
"bytes"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"filippo.io/age"
)
func EnsureIdentity(path string) (string, error) {
path = expandHome(path)
if data, err := os.ReadFile(path); err == nil { // #nosec G304 -- path is the configured local age identity file.
identity, err := parseIdentity(data)
if err != nil {
return "", err
}
return identity.Recipient().String(), nil
} else if !os.IsNotExist(err) {
return "", err
}
identity, err := age.GenerateX25519Identity()
if err != nil {
return "", err
}
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return "", err
}
data := []byte(identity.String() + "\n")
if err := os.WriteFile(path, data, 0o600); err != nil {
return "", err
}
return identity.Recipient().String(), nil
}
func RecipientFromIdentity(path string) (string, error) {
data, err := os.ReadFile(expandHome(path))
if err != nil {
return "", err
}
identity, err := parseIdentity(data)
if err != nil {
return "", err
}
return identity.Recipient().String(), nil
}
func encryptShard(plaintext []byte, recipientStrings []string) ([]byte, string, error) {
recipients, err := parseRecipients(recipientStrings)
if err != nil {
return nil, "", err
}
var compressed bytes.Buffer
gz := gzip.NewWriter(&compressed)
gz.ModTime = time.Unix(0, 0).UTC()
_, _ = gz.Write(plaintext)
_ = gz.Close()
var encrypted bytes.Buffer
w, err := age.Encrypt(&encrypted, recipients...)
if err != nil {
return nil, "", err
}
_, _ = w.Write(compressed.Bytes())
if err := w.Close(); err != nil {
return nil, "", err
}
return encrypted.Bytes(), sha256Hex(plaintext), nil
}
func decryptShard(ciphertext []byte, identityPath string) ([]byte, error) {
data, err := os.ReadFile(expandHome(identityPath)) // #nosec G304 -- path is the configured local age identity file.
if err != nil {
return nil, err
}
identity, err := parseIdentity(data)
if err != nil {
return nil, err
}
r, err := age.Decrypt(bytes.NewReader(ciphertext), identity)
if err != nil {
return nil, err
}
gz, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
defer func() { _ = gz.Close() }()
plaintext, err := io.ReadAll(gz)
if err != nil {
return nil, err
}
return plaintext, nil
}
func parseRecipients(values []string) ([]age.Recipient, error) {
var out []age.Recipient
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
recipient, err := age.ParseX25519Recipient(value)
if err != nil {
return nil, fmt.Errorf("parse age recipient: %w", err)
}
out = append(out, recipient)
}
if len(out) == 0 {
return nil, fmt.Errorf("at least one age recipient is required")
}
return out, nil
}
func parseIdentity(data []byte) (*age.X25519Identity, error) {
for _, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
identity, err := age.ParseX25519Identity(line)
if err != nil {
return nil, fmt.Errorf("parse age identity: %w", err)
}
return identity, nil
}
return nil, fmt.Errorf("age identity file is empty")
}
func sha256Hex(data []byte) string {
sum := sha256.Sum256(data)
return hex.EncodeToString(sum[:])
}

127
docs/boundary.md Normal file
View File

@ -0,0 +1,127 @@
# crawlkit boundary
`crawlkit` is the shared mechanics layer for local-first crawler archives. It
should make each crawl app smaller and more uniform without turning into a
generic Slack, Discord, Notion, or GitHub crawler.
The rule is simple: move behavior into `crawlkit` only when it is provider
neutral, reusable by at least two apps, and can preserve the app's existing
database and CLI contracts. Keep provider schemas, auth, API clients, cache
parsers, and product-specific ranking in the apps.
## adoption status
| app | branch | crawlkit usage | still app-owned |
| --- | --- | --- | --- |
| `gitcrawl` | `main` | config paths, SQLite openers, command/control metadata, status inventory, and the reference TUI/control contract | GitHub API sync, `gh` shim behavior, embeddings, clustering, inference, portable-store schema pruning, and the richer cluster TUI |
| `discrawl` | `main` | config/status/control, snapshot packing/import, git mirror mechanics, sync-state adapters, output helpers, and shared chat TUI | Discord bot API, desktop wiretap parsing, DM privacy filters, Discord schema, FTS/ranking, embeddings, and analytics |
| `slacrawl` | `feat/use-crawlkit` | config/status/control, snapshot packing/import, git mirror mechanics, state helpers, output helpers, and shared chat TUI | Slack API/Desktop parsing, token scopes, Slack schema, Slack text normalization, channel/thread semantics, and analytics |
| `notcrawl` | `feat/use-crawlkit` | config/status/control, snapshot packing/import, git mirror mechanics, output helpers, and shared document TUI | Notion API/Desktop parsing, Markdown rendering, page/comment/database schema, Notion FTS body construction, and data-source compatibility |
## owns
`crawlkit` should own these surfaces:
- Config paths, TOML loading defaults, runtime directories, and token
diagnostics that are the same across apps.
- SQLite connection hygiene: read-only opens, busy timeouts, WAL pragmas,
schema-version checks, transactions, safe identifier quoting, and generic
query helpers.
- Snapshot packing: manifest format, JSONL/Gzip shards, table filters,
per-file fingerprints, import progress, incremental import planning,
sidecar registration, backward-compatible manifest reads, and import
callbacks.
- Git mirror mechanics: clone/init, pull, origin management, path-scoped
commits, push retry behavior, and portable SQLite checkout cleanup.
- Sync freshness semantics: cursor/freshness records, stale checks, manifest
import state, and adapters for legacy table shapes.
- Embedding provider clients and vector math once extracted: OpenAI-compatible,
Ollama, llama.cpp, probe diagnostics, cosine search, top-k selection,
reciprocal-rank fusion, vector encoding, and dimension validation.
- FTS utilities that do not know app schemas: query escaping, snippets,
rebuild/optimize helpers, deferred refresh orchestration, and progress logs.
- Terminal archive browsing primitives: pane layout, sorting, focus, mouse
actions, menus, detail rendering primitives, and local/remote status chrome.
- Safe read-only desktop-cache snapshot helpers. The provider-specific parsing
of those snapshots stays in the apps.
## does not own
`crawlkit` should not own these surfaces:
- Slack, Discord, Notion, GitHub, or future provider API clients.
- App-specific auth flows, token scopes, rate-limit policy, and provider
object normalization.
- App database schemas for messages, pages, threads, issues, members, blocks,
comments, channels, guilds, or workspaces.
- Provider desktop-cache parsing such as Slack LevelDB records, Discord cache
rows, or Notion SQLite object trees.
- App-specific FTS bodies and ranking, such as Notion display-tree ordering,
Slack mention normalization, Discord member search, and GitHub issue/PR
syntax.
- Summarization, clustering, triage inference, or prompts until the same
behavior exists in more than one app.
- App CLI command contracts. Shared helpers can format JSON/text/log output,
but the apps decide command names, flags, backward-compatible aliases, and
deprecation behavior.
## current app seams
| app | embeddings/search/inference | sync state | snapshot, sqlite, remote |
| --- | --- | --- | --- |
| `gitcrawl` | Has the richest inference path: OpenAI-only embeddings, local thread vectors, exact cosine neighbors, durable clusters, and GitHub thread/document FTS. The vector math and portable embedding client should move to `crawlkit`; GitHub thread task construction, clustering, and prompts stay app-owned. | Uses app-owned repo sync and portable metadata. Do not force it into the shared `sync_state` table. | Has the most mature portable-store git behavior: clone/pull, dirty checkout recovery, SQLite sidecar cleanup, and portable payload pruning. The generic git/SQLite checkout pieces belong in `mirror`; GitHub portable schema pruning stays app-owned. |
| `discrawl` | Has the best reusable embedding provider surface: OpenAI, OpenAI-compatible, Ollama, llama.cpp, probe checks, float32 blobs, semantic search, hybrid search, and RRF. Provider clients, vector encoding, cosine, top-k, and RRF should be extracted. Discord message/member FTS and privacy boundaries stay app-owned. | Uses a single `scope -> cursor` table with local-only scopes such as `wiretap:*`. Shared state should adapt to this shape, not migrate it. | Uses `snapshot` and `mirror`, with important app filters for DMs and local-only sync state. Embedding bundles are sidecars today; generic sidecar/binary-vector mechanics should move to `snapshot`, while DM exclusion remains in `discrawl`. |
| `slacrawl` | Has Slack FTS and Slack text/mention normalization. Embeddings are only reserved placeholders. Slack normalization and message FTS stay app-owned. | Closest to `crawlkit/state`: `source_name`, `entity_type`, `entity_id`, `value`, `updated_at`. It is the first app that can consume shared state directly. | Uses `snapshot` and `mirror` cleanly. Its remaining share logic is mostly table lists, search-index rebuilds, and import freshness. |
| `notcrawl` | Has page/comment FTS, display-tree page bodies, deferred FTS refresh, and maintain/rebuild commands. No embeddings yet. Deferred FTS orchestration can become shared; Notion page/comment FTS content stays app-owned. | Uses `source`, `entity_type`, `entity_id`, `cursor`, `synced_at`. Shared state needs column mapping or adapters before this can be de-duped safely. | Still carries custom manifest, JSONL/Gzip, Markdown sidecars, generated-path commits, and origin update behavior. The snapshot sidecar model and mirror path-scoped commit/origin helpers should let this converge without changing the Notion DB schema. |
## extraction order
1. Harden `mirror` first.
Add origin update semantics for existing checkouts, path-scoped commits so
publish never stages unrelated files, existing-origin pull for update flows,
and portable SQLite sidecar cleanup. This is the lowest-risk de-dupe because
every app already shells out to git in similar ways.
2. Expand `snapshot` sidecars.
Keep table export/import generic, but add first-class sidecar/bundle helpers
for Markdown pages and embedding JSONL/Gzip bundles. Apps still provide
filters, table lists, delete callbacks, FTS rebuild callbacks, and privacy
rules.
3. Add state adapters instead of one forced schema.
Keep the current source/entity/value schema as the canonical new shape, but
add adapters for `scope -> cursor` and `source/entity/cursor/synced_at`
stores. `state.ScopedStore` and `state.CursorStore` cover those legacy
shapes so apps can share freshness and stale checks without risky
migrations.
4. Extract embeddings and vector search.
Start from `discrawl/internal/embed` for provider clients and from
`gitcrawl/internal/vector` plus discrawl search helpers for cosine, top-k,
vector encoding, and reciprocal-rank fusion. Apps keep task selection,
content hashing policy, provider config placement, and result persistence.
5. Add generic FTS helpers.
Provide query escaping, snippets, rebuild/optimize wrappers, deferred refresh
orchestration, and progress logging. Do not move entity-specific FTS schemas
or ranking into `crawlkit`.
6. Keep inference app-owned until there are two implementations.
`gitcrawl` clustering and summary-oriented work should not be generalized
yet. Extract only the provider/vector primitives it shares with chat/document
crawlers.
## compatibility gates
Every extraction must keep these constraints:
- Do not change existing app table shapes unless the app migration is explicitly
backward-compatible and tested against old fixtures.
- Do not change app command names, flags, JSON shape, or deprecated aliases
unless the downstream app changelog calls it out.
- Do not touch live stores during tests. Use temp homes, temp configs, and temp
SQLite files.
- Use `GOWORK=off` when proving the public `crawlkit` API so local workspaces
do not hide missing release tags.
- Keep privacy filters in the app layer. `crawlkit` can run a filter callback;
it should not know what a Discord DM or Slack private channel means.

View File

@ -14,35 +14,37 @@ go vet ./...
go test ./...
```
3. Test downstream apps against the local checkout through a temporary Go workspace.
4. Merge `crawlkit` to `main`.
5. Tag the next semver release from `main`:
3. Update docs and changelogs in `crawlkit` plus every downstream app branch
that consumes the release.
4. Test downstream apps against the local checkout through a temporary Go workspace.
5. Merge `crawlkit` to `main`.
6. Tag the next semver release from `main`:
```bash
git tag -s v0.4.0
git tag -s v0.5.0
git push origin main
git push origin v0.4.0
git push origin v0.5.0
```
6. Prime and verify module proxy visibility:
7. Prime and verify module proxy visibility:
```bash
GOPROXY=https://proxy.golang.org go list -m github.com/vincentkoc/crawlkit@v0.4.0
go list -m github.com/vincentkoc/crawlkit@v0.4.0
GOPROXY=https://proxy.golang.org go list -m github.com/openclaw/crawlkit@v0.5.0
go list -m github.com/openclaw/crawlkit@v0.5.0
```
7. Bump downstream apps to the new tag and commit their `go.mod`/`go.sum` updates:
8. Bump downstream apps to the new tag and commit their `go.mod`/`go.sum` updates:
```bash
go get github.com/vincentkoc/crawlkit@v0.4.0
go get github.com/openclaw/crawlkit@v0.5.0
go mod tidy
```
`pkg.go.dev` indexes public modules automatically after the tag is reachable.
Use a patch tag such as `v0.3.17` only for narrow bug fixes on the existing API.
Use a minor tag such as `v0.4.0` for broad shared TUI or crawler infrastructure
changes. This branch is a `v0.4.0`-shaped release.
Use a patch tag only for narrow bug fixes on the existing API. Use a minor tag
for broad crawler infrastructure changes. The module-path move needs a new tag
on `openclaw/crawlkit` before downstream apps can drop local `replace` lines.
## Versioning
@ -50,5 +52,5 @@ Keep `v0.x.y` while the downstream crawler rewires are still settling. If the
module ever reaches `v2`, Go requires the module path to become:
```text
github.com/vincentkoc/crawlkit/v2
github.com/openclaw/crawlkit/v2
```

91
embed/ollama.go Normal file
View File

@ -0,0 +1,91 @@
package embed
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)
type ollamaProvider struct {
client *http.Client
baseURL string
model string
maxInputChars int
}
type ollamaEmbedRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
}
type ollamaEmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"`
}
func newOllamaProvider(settings providerSettings) Provider {
return &ollamaProvider{
client: settings.HTTPClient,
baseURL: settings.BaseURL,
model: settings.Model,
maxInputChars: settings.MaxInputChars,
}
}
func (p *ollamaProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
if len(inputs) == 0 {
return EmbeddingBatch{Model: p.model}, nil
}
payload := ollamaEmbedRequest{
Model: p.model,
Input: trimInputs(inputs, p.maxInputChars),
}
var response ollamaEmbedResponse
if err := postJSON(ctx, p.client, p.baseURL+"/api/embed", "", payload, &response); err != nil {
return EmbeddingBatch{}, err
}
if len(response.Embeddings) != len(inputs) {
return EmbeddingBatch{}, fmt.Errorf("ollama embedding response returned %d vectors for %d inputs", len(response.Embeddings), len(inputs))
}
dimensions, err := inferDimensions(response.Embeddings)
if err != nil {
return EmbeddingBatch{}, err
}
model := response.Model
if model == "" {
model = p.model
}
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: response.Embeddings}, nil
}
func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, payload any, target any) error {
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal embedding request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build embedding request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("embedding request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
}
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
return fmt.Errorf("decode embedding response: %w", err)
}
return nil
}

View File

@ -0,0 +1,82 @@
package embed
import (
"context"
"fmt"
"net/http"
)
type openAICompatibleProvider struct {
client *http.Client
baseURL string
apiKey string
model string
maxInputChars int
}
type openAIEmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
}
type openAIEmbeddingResponse struct {
Model string `json:"model"`
Data []openAIEmbeddingItem `json:"data"`
}
type openAIEmbeddingItem struct {
Index *int `json:"index"`
Embedding []float32 `json:"embedding"`
}
func newOpenAICompatibleProvider(settings providerSettings) Provider {
return &openAICompatibleProvider{
client: settings.HTTPClient,
baseURL: settings.BaseURL,
apiKey: settings.APIKey,
model: settings.Model,
maxInputChars: settings.MaxInputChars,
}
}
func (p *openAICompatibleProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
if len(inputs) == 0 {
return EmbeddingBatch{Model: p.model}, nil
}
payload := openAIEmbeddingRequest{
Model: p.model,
Input: trimInputs(inputs, p.maxInputChars),
}
var response openAIEmbeddingResponse
if err := postJSON(ctx, p.client, p.baseURL+"/embeddings", p.apiKey, payload, &response); err != nil {
return EmbeddingBatch{}, err
}
if len(response.Data) != len(inputs) {
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response returned %d vectors for %d inputs", len(response.Data), len(inputs))
}
vectors := make([][]float32, len(inputs))
seen := make([]bool, len(inputs))
for position, item := range response.Data {
index := position
if item.Index != nil {
index = *item.Index
}
if index < 0 || index >= len(inputs) {
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response index %d out of range", index)
}
if seen[index] {
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response duplicated index %d", index)
}
seen[index] = true
vectors[index] = item.Embedding
}
dimensions, err := inferDimensions(vectors)
if err != nil {
return EmbeddingBatch{}, err
}
model := response.Model
if model == "" {
model = p.model
}
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: vectors}, nil
}

317
embed/provider.go Normal file
View File

@ -0,0 +1,317 @@
package embed
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
)
const (
ProviderOpenAI = "openai"
ProviderOllama = "ollama"
ProviderLlamaCpp = "llamacpp"
ProviderOpenAICompatible = "openai_compatible"
DefaultOpenAIBaseURL = "https://api.openai.com/v1"
DefaultOllamaBaseURL = "http://127.0.0.1:11434"
DefaultLlamaCppBaseURL = "http://127.0.0.1:8080/v1"
DefaultOpenAIModel = "text-embedding-3-small"
DefaultLocalEmbeddingModel = "nomic-embed-text"
DefaultBatchSize = 64
DefaultMaxInputChars = 12000
DefaultRequestTimeout = 2 * time.Minute
DefaultProbeTimeout = 2 * time.Second
)
type Config struct {
Provider string
Model string
BaseURL string
APIKeyEnv string
RequestTimeout string
MaxInputChars int
}
type Provider interface {
Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error)
}
type EmbeddingBatch struct {
Model string
Dimensions int
Vectors [][]float32
}
type HTTPError struct {
StatusCode int
Body string
}
func (e *HTTPError) Error() string {
return fmt.Sprintf("embedding request failed with HTTP %d: %s", e.StatusCode, e.Body)
}
func IsRateLimitError(err error) bool {
var httpErr *HTTPError
return errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusTooManyRequests
}
type CheckResult struct {
Provider string
Model string
BaseURL string
Status string
Warning string
Probed bool
}
type Option func(*providerOptions)
type providerOptions struct {
httpClient *http.Client
timeoutOverride time.Duration
}
type providerSettings struct {
Name string
Model string
BaseURL string
APIKey string
MaxInputChars int
Timeout time.Duration
HTTPClient *http.Client
}
func WithHTTPClient(client *http.Client) Option {
return func(opts *providerOptions) {
opts.httpClient = client
}
}
func WithRequestTimeout(timeout time.Duration) Option {
return func(opts *providerOptions) {
opts.timeoutOverride = timeout
}
}
func NewProvider(cfg Config, opts ...Option) (Provider, error) {
settings, err := resolveProviderConfig(cfg, true, opts...)
if err != nil {
return nil, err
}
return newProvider(settings)
}
func CheckProvider(ctx context.Context, cfg Config) CheckResult {
settings, err := resolveProviderConfig(cfg, true, WithRequestTimeout(DefaultProbeTimeout))
if err != nil {
return CheckResult{
Provider: normalizedProviderName(cfg.Provider),
Model: strings.TrimSpace(cfg.Model),
BaseURL: strings.TrimSpace(cfg.BaseURL),
Status: "warning",
Warning: err.Error(),
}
}
result := CheckResult{
Provider: settings.Name,
Model: settings.Model,
BaseURL: settings.BaseURL,
Status: "ok",
}
if !shouldProbe(settings) {
return result
}
provider, err := newProvider(settings)
if err != nil {
result.Status = "warning"
result.Warning = err.Error()
return result
}
probeCtx, cancel := context.WithTimeout(ctx, DefaultProbeTimeout)
defer cancel()
if _, err := provider.Embed(probeCtx, []string{"crawlkit probe"}); err != nil {
result.Status = "warning"
result.Warning = err.Error()
return result
}
result.Probed = true
return result
}
func resolveProviderConfig(cfg Config, validateAPIKey bool, opts ...Option) (providerSettings, error) {
options := providerOptions{}
for _, opt := range opts {
opt(&options)
}
name := normalizedProviderName(cfg.Provider)
if name == "" {
name = ProviderOpenAI
}
model := strings.TrimSpace(cfg.Model)
if model == "" {
model = defaultModel(name)
}
baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/")
if baseURL == "" {
switch name {
case ProviderOpenAI:
baseURL = DefaultOpenAIBaseURL
case ProviderOllama:
baseURL = DefaultOllamaBaseURL
case ProviderLlamaCpp:
baseURL = DefaultLlamaCppBaseURL
case ProviderOpenAICompatible:
return providerSettings{}, fmt.Errorf("embedding provider %q requires base_url", name)
}
}
timeout := DefaultRequestTimeout
if strings.TrimSpace(cfg.RequestTimeout) != "" {
parsed, err := time.ParseDuration(cfg.RequestTimeout)
if err != nil {
return providerSettings{}, fmt.Errorf("parse embeddings request_timeout: %w", err)
}
if parsed <= 0 {
return providerSettings{}, errors.New("embeddings request_timeout must be positive")
}
timeout = parsed
}
if options.timeoutOverride > 0 && options.timeoutOverride < timeout {
timeout = options.timeoutOverride
}
maxInputChars := cfg.MaxInputChars
if maxInputChars <= 0 {
maxInputChars = DefaultMaxInputChars
}
switch name {
case ProviderOpenAI, ProviderOllama, ProviderLlamaCpp, ProviderOpenAICompatible:
default:
return providerSettings{}, fmt.Errorf("unsupported embedding provider %q", name)
}
apiKey, err := resolveAPIKey(name, cfg.APIKeyEnv, validateAPIKey)
if err != nil {
return providerSettings{}, err
}
client := options.httpClient
if client == nil {
client = &http.Client{Timeout: timeout}
}
if _, err := url.ParseRequestURI(baseURL); err != nil {
return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: %w", baseURL, err)
}
return providerSettings{
Name: name,
Model: model,
BaseURL: baseURL,
APIKey: apiKey,
MaxInputChars: maxInputChars,
Timeout: timeout,
HTTPClient: client,
}, nil
}
func newProvider(settings providerSettings) (Provider, error) {
switch settings.Name {
case ProviderOllama:
return newOllamaProvider(settings), nil
case ProviderOpenAI, ProviderLlamaCpp, ProviderOpenAICompatible:
return newOpenAICompatibleProvider(settings), nil
default:
return nil, fmt.Errorf("unsupported embedding provider %q", settings.Name)
}
}
func resolveAPIKey(provider, apiKeyEnv string, validate bool) (string, error) {
envName := strings.TrimSpace(apiKeyEnv)
required := provider == ProviderOpenAI
if envName == "" {
if required {
envName = "OPENAI_API_KEY"
} else {
return "", nil
}
}
value := strings.TrimSpace(os.Getenv(envName))
if value == "" {
if required || validate {
return "", fmt.Errorf("embedding provider %q requires API key env %s", provider, envName)
}
return "", nil
}
return value, nil
}
func normalizedProviderName(provider string) string {
return strings.ToLower(strings.TrimSpace(provider))
}
func defaultModel(provider string) string {
switch provider {
case ProviderOllama, ProviderLlamaCpp:
return DefaultLocalEmbeddingModel
default:
return DefaultOpenAIModel
}
}
func shouldProbe(settings providerSettings) bool {
switch settings.Name {
case ProviderOllama, ProviderLlamaCpp:
return true
case ProviderOpenAICompatible:
return isLoopbackBaseURL(settings.BaseURL)
default:
return false
}
}
func isLoopbackBaseURL(rawURL string) bool {
parsed, err := url.Parse(rawURL)
if err != nil {
return false
}
host := parsed.Hostname()
if host == "localhost" {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}
func trimInputs(inputs []string, maxChars int) []string {
if maxChars <= 0 {
maxChars = DefaultMaxInputChars
}
out := make([]string, len(inputs))
for i, input := range inputs {
runes := []rune(input)
if len(runes) > maxChars {
runes = runes[:maxChars]
}
out[i] = string(runes)
}
return out
}
func inferDimensions(vectors [][]float32) (int, error) {
dimensions := 0
for _, vector := range vectors {
if len(vector) == 0 {
return 0, errors.New("embedding response contained an empty vector")
}
if dimensions == 0 {
dimensions = len(vector)
continue
}
if len(vector) != dimensions {
return 0, fmt.Errorf("embedding response dimensions mismatch: got %d want %d", len(vector), dimensions)
}
}
return dimensions, nil
}

453
embed/provider_test.go Normal file
View File

@ -0,0 +1,453 @@
package embed
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
)
func TestOllamaProviderEmbeds(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/embed", r.URL.Path)
assert.Equal(t, http.MethodPost, r.Method)
var req ollamaEmbedRequest
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, "nomic-embed-text", req.Model)
assert.Equal(t, []string{"abcd", "xy"}, req.Input)
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`))
}))
defer server.Close()
provider, err := NewProvider(Config{
Provider: ProviderOllama,
Model: "nomic-embed-text",
BaseURL: server.URL,
MaxInputChars: 4,
RequestTimeout: "5s",
})
require.NoError(t, err)
batch, err := provider.Embed(context.Background(), []string{"abcdef", "xy"})
require.NoError(t, err)
require.Equal(t, "nomic-embed-text", batch.Model)
require.Equal(t, 3, batch.Dimensions)
require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, batch.Vectors)
}
type assertAPI struct{}
type requireAPI struct{}
var assert assertAPI
var require requireAPI
func (assertAPI) Equal(t *testing.T, want, got any) bool {
t.Helper()
if !reflect.DeepEqual(want, got) {
t.Errorf("not equal:\nwant: %#v\n got: %#v", want, got)
return false
}
return true
}
func (assertAPI) NoError(t *testing.T, err error) bool {
t.Helper()
if err != nil {
t.Errorf("unexpected error: %v", err)
return false
}
return true
}
func (requireAPI) NoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func (requireAPI) Equal(t *testing.T, want, got any) {
t.Helper()
if !reflect.DeepEqual(want, got) {
t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got)
}
}
func (requireAPI) Same(t *testing.T, want, got any) {
t.Helper()
if !reflect.ValueOf(want).IsValid() || !reflect.ValueOf(got).IsValid() ||
reflect.ValueOf(want).Pointer() != reflect.ValueOf(got).Pointer() {
t.Fatalf("not same:\nwant: %#v\n got: %#v", want, got)
}
}
func (requireAPI) True(t *testing.T, value bool) {
t.Helper()
if !value {
t.Fatal("expected true")
}
}
func (requireAPI) False(t *testing.T, value bool) {
t.Helper()
if value {
t.Fatal("expected false")
}
}
func (requireAPI) Empty(t *testing.T, value string) {
t.Helper()
if value != "" {
t.Fatalf("expected empty string, got %q", value)
}
}
func (requireAPI) Contains(t *testing.T, value, needle string) {
t.Helper()
if !strings.Contains(value, needle) {
t.Fatalf("expected %q to contain %q", value, needle)
}
}
func (requireAPI) ErrorContains(t *testing.T, err error, needle string) {
t.Helper()
if err == nil {
t.Fatalf("expected error containing %q, got nil", needle)
}
if !strings.Contains(err.Error(), needle) {
t.Fatalf("expected error containing %q, got %q", needle, err.Error())
}
}
func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/embeddings", r.URL.Path)
assert.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
var req openAIEmbeddingRequest
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, "local-model", req.Model)
assert.Equal(t, []string{"one", "two"}, req.Input)
_, _ = w.Write([]byte(`{
"model":"local-model",
"data":[
{"index":1,"embedding":[3,4]},
{"index":0,"embedding":[1,2]}
]
}`))
}))
defer server.Close()
t.Setenv("CRAWLKIT_EMBED_KEY", "secret")
provider, err := NewProvider(Config{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
APIKeyEnv: "CRAWLKIT_EMBED_KEY",
RequestTimeout: "5s",
})
require.NoError(t, err)
batch, err := provider.Embed(context.Background(), []string{"one", "two"})
require.NoError(t, err)
require.Equal(t, "local-model", batch.Model)
require.Equal(t, 2, batch.Dimensions)
require.Equal(t, [][]float32{{1, 2}, {3, 4}}, batch.Vectors)
}
func TestProviderFactoryDefaultsAndValidation(t *testing.T) {
t.Setenv("OPENAI_API_KEY", "openai-secret")
openAI, err := resolveProviderConfig(Config{
Provider: ProviderOpenAI,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL)
require.Equal(t, DefaultOpenAIModel, openAI.Model)
require.Equal(t, "openai-secret", openAI.APIKey)
ollama, err := resolveProviderConfig(Config{
Provider: ProviderOllama,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL)
require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model)
llamaCpp, err := resolveProviderConfig(Config{
Provider: ProviderLlamaCpp,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL)
_, err = resolveProviderConfig(Config{
Provider: ProviderOpenAICompatible,
RequestTimeout: "5s",
}, true)
require.ErrorContains(t, err, "requires base_url")
}
func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) {
t.Setenv("OPENAI_API_KEY", "")
_, err := NewProvider(Config{
Provider: ProviderOpenAI,
RequestTimeout: "5s",
})
require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY")
}
func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) {
t.Setenv("MISSING_EMBED_KEY", "")
_, err := NewProvider(Config{
Provider: "bogus",
APIKeyEnv: "MISSING_EMBED_KEY",
RequestTimeout: "5s",
})
require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"")
}
func TestCheckProviderProbesLocalProvider(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/embed", r.URL.Path)
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`))
}))
defer server.Close()
result := CheckProvider(context.Background(), Config{
Provider: ProviderOllama,
Model: "nomic-embed-text",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.Equal(t, "ok", result.Status)
require.True(t, result.Probed)
require.Empty(t, result.Warning)
require.Equal(t, server.URL, result.BaseURL)
}
func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "not ready", http.StatusServiceUnavailable)
}))
defer server.Close()
result := CheckProvider(context.Background(), Config{
Provider: ProviderOllama,
Model: "nomic-embed-text",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.Equal(t, "warning", result.Status)
require.Contains(t, result.Warning, "HTTP 503")
require.False(t, result.Probed)
}
func TestProviderExposesRateLimitErrors(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "rate limited", http.StatusTooManyRequests)
}))
defer server.Close()
provider, err := NewProvider(Config{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)
_, err = provider.Embed(context.Background(), []string{"one"})
require.ErrorContains(t, err, "HTTP 429")
require.True(t, IsRateLimitError(err))
}
func TestProviderRejectsInvalidResponses(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`))
}))
defer server.Close()
provider, err := NewProvider(Config{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)
_, err = provider.Embed(context.Background(), []string{"one", "two"})
require.ErrorContains(t, err, "dimensions mismatch")
}
func TestEmbeddingProvidersHandleEmptyInputsAndIndexErrors(t *testing.T) {
t.Parallel()
settings := providerSettings{
Name: ProviderOllama,
Model: "model",
BaseURL: "http://127.0.0.1:1",
MaxInputChars: 10,
HTTPClient: http.DefaultClient,
}
ollama := newOllamaProvider(settings)
batch, err := ollama.Embed(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, "model", batch.Model)
settings.Name = ProviderOpenAICompatible
openai := newOpenAICompatibleProvider(settings)
batch, err = openai.Embed(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, "model", batch.Model)
tests := []struct {
name string
body string
inputs []string
want string
}{
{name: "count", body: `{"data":[]}`, inputs: []string{"one"}, want: "returned 0 vectors for 1 inputs"},
{name: "range", body: `{"data":[{"index":2,"embedding":[1]}]}`, inputs: []string{"one"}, want: "index 2 out of range"},
{name: "duplicate", body: `{"data":[{"index":0,"embedding":[1]},{"index":0,"embedding":[2]}]}`, inputs: []string{"one", "two"}, want: "duplicated index 0"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(tc.body))
}))
defer server.Close()
provider, err := NewProvider(Config{
Provider: ProviderOpenAICompatible,
Model: "model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)
_, err = provider.Embed(context.Background(), tc.inputs)
require.ErrorContains(t, err, tc.want)
})
}
}
func TestProviderOptionsAndProbeDecisions(t *testing.T) {
t.Parallel()
client := &http.Client{Timeout: time.Second}
settings, err := resolveProviderConfig(Config{
Provider: ProviderOllama,
BaseURL: "http://127.0.0.1:11434/",
RequestTimeout: "30s",
}, true, WithHTTPClient(client), WithRequestTimeout(50*time.Millisecond))
require.NoError(t, err)
require.Same(t, client, settings.HTTPClient)
require.Equal(t, 50*time.Millisecond, settings.Timeout)
require.Equal(t, "http://127.0.0.1:11434", settings.BaseURL)
require.True(t, shouldProbe(settings))
require.True(t, isLoopbackBaseURL("http://localhost:8080/v1"))
require.True(t, isLoopbackBaseURL("http://[::1]:8080/v1"))
require.False(t, isLoopbackBaseURL("https://api.example.com/v1"))
require.False(t, isLoopbackBaseURL("://bad"))
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAI}))
require.True(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "http://localhost:8080/v1"}))
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "https://api.example.com/v1"}))
}
func TestProviderValidationEdges(t *testing.T) {
t.Parallel()
_, err := resolveProviderConfig(Config{
Provider: ProviderOllama,
RequestTimeout: "not-a-duration",
}, true)
require.ErrorContains(t, err, "parse embeddings request_timeout")
_, err = resolveProviderConfig(Config{
Provider: ProviderOllama,
RequestTimeout: "0s",
}, true)
require.ErrorContains(t, err, "must be positive")
_, err = resolveProviderConfig(Config{
Provider: ProviderOllama,
BaseURL: "://bad",
}, true)
require.ErrorContains(t, err, "invalid embeddings base_url")
key, err := resolveAPIKey(ProviderOpenAICompatible, "MISSING_EMBED_KEY", false)
require.NoError(t, err)
require.Empty(t, key)
_, err = newProvider(providerSettings{Name: "bogus"})
require.ErrorContains(t, err, "unsupported embedding provider")
require.Equal(t, []string{"abc"}, trimInputs([]string{"abc"}, 0))
_, err = inferDimensions([][]float32{{}})
require.ErrorContains(t, err, "empty vector")
}
func TestOllamaProviderResponseEdges(t *testing.T) {
t.Parallel()
countServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/embed", r.URL.Path)
_, _ = w.Write([]byte(`{"embeddings":[]}`))
}))
defer countServer.Close()
provider := newOllamaProvider(providerSettings{
HTTPClient: countServer.Client(),
BaseURL: countServer.URL,
Model: "fallback-model",
MaxInputChars: 10,
})
_, err := provider.Embed(context.Background(), []string{"one"})
require.ErrorContains(t, err, "returned 0 vectors for 1 inputs")
modelServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/embed", r.URL.Path)
_, _ = w.Write([]byte(`{"embeddings":[[1,2]]}`))
}))
defer modelServer.Close()
provider = newOllamaProvider(providerSettings{
HTTPClient: modelServer.Client(),
BaseURL: modelServer.URL,
Model: "fallback-model",
MaxInputChars: 10,
})
batch, err := provider.Embed(context.Background(), []string{"one"})
require.NoError(t, err)
require.Equal(t, "fallback-model", batch.Model)
}
func TestCheckProviderSkipsRemoteCompatibleProbe(t *testing.T) {
t.Parallel()
result := CheckProvider(context.Background(), Config{
Provider: ProviderOpenAICompatible,
Model: "remote-model",
BaseURL: "https://api.example.com/v1",
RequestTimeout: "5s",
})
require.Equal(t, "ok", result.Status)
require.False(t, result.Probed)
require.Empty(t, result.Warning)
}

22
go.mod
View File

@ -1,31 +1,32 @@
module github.com/vincentkoc/crawlkit
module github.com/openclaw/crawlkit
go 1.26.2
require (
filippo.io/age v1.3.1
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/charmbracelet/x/ansi v0.11.6
github.com/charmbracelet/x/ansi v0.11.7
github.com/charmbracelet/x/term v0.2.2
github.com/mattn/go-isatty v0.0.20
github.com/pelletier/go-toml/v2 v2.3.0
github.com/mattn/go-isatty v0.0.22
github.com/pelletier/go-toml/v2 v2.3.1
modernc.org/sqlite v1.50.0
)
require (
filippo.io/hpke v0.4.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/clipperhouse/displaywidth v0.11.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/mattn/go-runewidth v0.0.23 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
@ -33,8 +34,9 @@ require (
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.3.8 // indirect
golang.org/x/text v0.31.0 // indirect
modernc.org/libc v1.72.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect

43
go.sum
View File

@ -1,3 +1,9 @@
c2sp.org/CCTV/age v0.0.0-20251208015420-e9274a7bdbfd h1:ZLsPO6WdZ5zatV4UfVpr7oAwLGRZ+sebTUruuM4Ra3M=
c2sp.org/CCTV/age v0.0.0-20251208015420-e9274a7bdbfd/go.mod h1:SrHC2C7r5GkDk8R+NFVzYy/sdj0Ypg9htaPXQq5Cqeo=
filippo.io/age v1.3.1 h1:hbzdQOJkuaMEpRCLSN1/C5DX74RPcNCk6oqhKMXmZi0=
filippo.io/age v1.3.1/go.mod h1:EZorDTYUxt836i3zdori5IJX/v2Lj6kWFU0cfh6C0D4=
filippo.io/hpke v0.4.0 h1:p575VVQ6ted4pL+it6M00V/f2qTZITO0zgmdKCkd5+A=
filippo.io/hpke v0.4.0/go.mod h1:EmAN849/P3qdeK+PCMkDpDm83vRHM5cDipBJ8xbQLVY=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
@ -8,18 +14,16 @@ github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
@ -30,14 +34,14 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@ -46,14 +50,16 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
@ -61,11 +67,10 @@ golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=

View File

@ -23,6 +23,9 @@ func EnsureRepo(ctx context.Context, opts Options) error {
return errors.New("repo path is required")
}
if _, err := os.Stat(filepath.Join(opts.RepoPath, ".git")); err == nil {
if opts.Remote != "" {
return setOrigin(ctx, opts)
}
return nil
}
if opts.Remote != "" {
@ -49,6 +52,17 @@ func EnsureRepo(ctx context.Context, opts Options) error {
return nil
}
func EnsureRemote(ctx context.Context, opts Options) error {
opts = normalize(opts)
if opts.Remote == "" {
return errors.New("remote is required")
}
if err := EnsureRepo(ctx, opts); err != nil {
return err
}
return setOrigin(ctx, opts)
}
func Pull(ctx context.Context, opts Options) error {
opts = normalize(opts)
if opts.Remote == "" {
@ -67,19 +81,51 @@ func Pull(ctx context.Context, opts Options) error {
return run(ctx, opts.RepoPath, opts.Git, "checkout", "-B", opts.Branch, "origin/"+opts.Branch)
}
func PullCurrent(ctx context.Context, opts Options) error {
opts = normalize(opts)
if opts.Remote != "" {
return Pull(ctx, opts)
}
if err := EnsureRepo(ctx, opts); err != nil {
return err
}
if err := run(ctx, opts.RepoPath, opts.Git, "fetch", "--prune", "origin"); err != nil {
return err
}
if _, err := output(ctx, opts.RepoPath, opts.Git, "rev-parse", "--verify", "refs/heads/"+opts.Branch); err != nil {
return run(ctx, opts.RepoPath, opts.Git, "checkout", "-B", opts.Branch, "origin/"+opts.Branch)
}
if err := run(ctx, opts.RepoPath, opts.Git, "checkout", opts.Branch); err != nil {
return err
}
return run(ctx, opts.RepoPath, opts.Git, "pull", "--ff-only", "origin", opts.Branch)
}
func Commit(ctx context.Context, opts Options, message string) (bool, error) {
return CommitPaths(ctx, opts, message, []string{"."})
}
func CommitPaths(ctx context.Context, opts Options, message string, paths []string) (bool, error) {
opts = normalize(opts)
if message == "" {
message = "archive: update snapshot"
}
if err := run(ctx, opts.RepoPath, opts.Git, "add", "."); err != nil {
return false, err
}
dirty, err := Dirty(ctx, opts)
pathspecs, err := cleanPathspecs(paths)
if err != nil {
return false, err
}
if !dirty {
if len(pathspecs) == 0 {
return false, nil
}
args := append([]string{"add", "--"}, pathspecs...)
if err := run(ctx, opts.RepoPath, opts.Git, args...); err != nil {
return false, err
}
staged, err := staged(ctx, opts)
if err != nil {
return false, err
}
if !staged {
return false, nil
}
if err := run(ctx, opts.RepoPath, opts.Git,
@ -117,6 +163,37 @@ func Dirty(ctx context.Context, opts Options) (bool, error) {
return strings.TrimSpace(out) != "", nil
}
func CleanSQLiteSidecars(rootDir string) (int, error) {
rootDir = strings.TrimSpace(rootDir)
if rootDir == "" {
return 0, errors.New("root dir is required")
}
count := 0
err := filepath.WalkDir(rootDir, func(path string, entry os.DirEntry, err error) error {
if err != nil {
return err
}
if entry.IsDir() {
if entry.Name() == ".git" {
return filepath.SkipDir
}
return nil
}
if !isSQLiteSidecar(path) {
return nil
}
if err := os.Remove(path); err != nil {
return fmt.Errorf("remove sqlite sidecar %s: %w", path, err)
}
count++
return nil
})
if err != nil {
return count, fmt.Errorf("clean sqlite sidecars: %w", err)
}
return count, nil
}
func normalize(opts Options) Options {
opts.RepoPath = strings.TrimSpace(opts.RepoPath)
opts.Remote = strings.TrimSpace(opts.Remote)
@ -131,6 +208,63 @@ func normalize(opts Options) Options {
return opts
}
func setOrigin(ctx context.Context, opts Options) error {
current, err := output(ctx, opts.RepoPath, opts.Git, "remote", "get-url", "origin")
if err != nil {
return run(ctx, opts.RepoPath, opts.Git, "remote", "add", "origin", opts.Remote)
}
if strings.TrimSpace(current) == opts.Remote {
return nil
}
return run(ctx, opts.RepoPath, opts.Git, "remote", "set-url", "origin", opts.Remote)
}
func cleanPathspecs(paths []string) ([]string, error) {
var out []string
for _, path := range paths {
path = strings.TrimSpace(path)
if path == "" {
continue
}
if filepath.IsAbs(path) {
return nil, fmt.Errorf("commit path %q must be relative", path)
}
clean := filepath.Clean(path)
if clean == "." {
out = append(out, ".")
continue
}
if clean == ".." || strings.HasPrefix(clean, ".."+string(filepath.Separator)) {
return nil, fmt.Errorf("commit path %q must stay inside the repo", path)
}
out = append(out, filepath.ToSlash(clean))
}
return out, nil
}
func staged(ctx context.Context, opts Options) (bool, error) {
opts = normalize(opts)
out, err := output(ctx, opts.RepoPath, opts.Git, "diff", "--cached", "--quiet")
if err == nil {
return false, nil
}
var exitErr *exec.ExitError
if errors.As(err, &exitErr) && exitErr.ExitCode() == 1 {
return true, nil
}
return false, fmt.Errorf("git diff --cached --quiet: %w\n%s", err, strings.TrimSpace(out))
}
func isSQLiteSidecar(path string) bool {
name := filepath.Base(path)
return strings.HasSuffix(name, ".db-wal") ||
strings.HasSuffix(name, ".db-shm") ||
strings.HasSuffix(name, ".sqlite-wal") ||
strings.HasSuffix(name, ".sqlite-shm") ||
strings.HasSuffix(name, ".sqlite3-wal") ||
strings.HasSuffix(name, ".sqlite3-shm")
}
func run(ctx context.Context, dir, git string, args ...string) error {
out, err := output(ctx, dir, git, args...)
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
@ -42,3 +43,146 @@ func TestEnsureRepoCommitDirty(t *testing.T) {
t.Fatal("repo should be clean after commit")
}
}
func TestEnsureRepoUpdatesExistingOrigin(t *testing.T) {
ctx := context.Background()
repo := filepath.Join(t.TempDir(), "share")
opts := Options{RepoPath: repo, Branch: "main"}
if err := EnsureRepo(ctx, opts); err != nil {
t.Fatal(err)
}
if err := run(ctx, repo, "git", "remote", "add", "origin", "https://example.invalid/old.git"); err != nil {
t.Fatal(err)
}
opts.Remote = "https://example.invalid/new.git"
if err := EnsureRepo(ctx, opts); err != nil {
t.Fatal(err)
}
out, err := output(ctx, repo, "git", "remote", "get-url", "origin")
if err != nil {
t.Fatal(err)
}
if strings.TrimSpace(out) != opts.Remote {
t.Fatalf("origin = %q, want %q", strings.TrimSpace(out), opts.Remote)
}
}
func TestCommitPathsDoesNotStageUnrelatedFiles(t *testing.T) {
ctx := context.Background()
repo := filepath.Join(t.TempDir(), "share")
opts := Options{RepoPath: repo, Branch: "main"}
if err := EnsureRepo(ctx, opts); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(repo, "manifest.json"), []byte("{}\n"), 0o600); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(repo, "notes.txt"), []byte("local draft\n"), 0o600); err != nil {
t.Fatal(err)
}
committed, err := CommitPaths(ctx, opts, "archive: manifest", []string{"manifest.json"})
if err != nil {
t.Fatal(err)
}
if !committed {
t.Fatal("expected commit")
}
tree, err := output(ctx, repo, "git", "ls-tree", "--name-only", "HEAD")
if err != nil {
t.Fatal(err)
}
if !strings.Contains(tree, "manifest.json") {
t.Fatalf("manifest was not committed: %q", tree)
}
if strings.Contains(tree, "notes.txt") {
t.Fatalf("unrelated file was committed: %q", tree)
}
status, err := output(ctx, repo, "git", "status", "--porcelain")
if err != nil {
t.Fatal(err)
}
if strings.TrimSpace(status) != "?? notes.txt" {
t.Fatalf("status = %q, want only untracked notes.txt", strings.TrimSpace(status))
}
}
func TestPullCurrentUsesExistingOrigin(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
remote := filepath.Join(dir, "remote.git")
seed := filepath.Join(dir, "seed")
repo := filepath.Join(dir, "share")
if err := run(ctx, "", "git", "init", "--bare", remote); err != nil {
t.Fatal(err)
}
if err := run(ctx, "", "git", "clone", remote, seed); err != nil {
t.Fatal(err)
}
if err := run(ctx, seed, "git", "checkout", "-B", "main"); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(seed, "manifest.json"), []byte("one\n"), 0o600); err != nil {
t.Fatal(err)
}
if err := run(ctx, seed, "git", "add", "manifest.json"); err != nil {
t.Fatal(err)
}
if err := run(ctx, seed, "git", "-c", "commit.gpgsign=false", "-c", "user.name=test", "-c", "user.email=test@example.invalid", "commit", "-m", "one"); err != nil {
t.Fatal(err)
}
if err := run(ctx, seed, "git", "push", "-u", "origin", "main"); err != nil {
t.Fatal(err)
}
if err := Pull(ctx, Options{RepoPath: repo, Remote: remote, Branch: "main"}); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(seed, "manifest.json"), []byte("two\n"), 0o600); err != nil {
t.Fatal(err)
}
if err := run(ctx, seed, "git", "add", "manifest.json"); err != nil {
t.Fatal(err)
}
if err := run(ctx, seed, "git", "-c", "commit.gpgsign=false", "-c", "user.name=test", "-c", "user.email=test@example.invalid", "commit", "-m", "two"); err != nil {
t.Fatal(err)
}
if err := run(ctx, seed, "git", "push", "origin", "main"); err != nil {
t.Fatal(err)
}
if err := PullCurrent(ctx, Options{RepoPath: repo, Branch: "main"}); err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(filepath.Join(repo, "manifest.json"))
if err != nil {
t.Fatal(err)
}
if string(data) != "two\n" {
t.Fatalf("manifest = %q, want updated content", data)
}
}
func TestCleanSQLiteSidecars(t *testing.T) {
dir := t.TempDir()
files := []string{"archive.db", "archive.db-wal", "archive.db-shm", "notes.txt"}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte(file), 0o600); err != nil {
t.Fatal(err)
}
}
removed, err := CleanSQLiteSidecars(dir)
if err != nil {
t.Fatal(err)
}
if removed != 2 {
t.Fatalf("removed = %d, want 2", removed)
}
for _, file := range []string{"archive.db-wal", "archive.db-shm"} {
if _, err := os.Stat(filepath.Join(dir, file)); !os.IsNotExist(err) {
t.Fatalf("%s should have been removed, err=%v", file, err)
}
}
for _, file := range []string{"archive.db", "notes.txt"} {
if _, err := os.Stat(filepath.Join(dir, file)); err != nil {
t.Fatalf("%s should remain: %v", file, err)
}
}
}

View File

@ -4,10 +4,13 @@ import (
"bufio"
"compress/gzip"
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"os"
"path/filepath"
@ -15,7 +18,7 @@ import (
"strings"
"time"
"github.com/vincentkoc/crawlkit/store"
"github.com/openclaw/crawlkit/store"
)
const ManifestName = "manifest.json"
@ -38,6 +41,7 @@ type ImportOptions struct {
DeleteTables []string
DeleteTable DeleteFunc
Filter RowFilter
ImportRow RowImportFunc
Progress func(ImportProgress)
BeforeImport func(context.Context, *sql.Tx) error
AfterImport func(context.Context, *sql.Tx) error
@ -45,6 +49,8 @@ type ImportOptions struct {
type RowFilter func(table string, row map[string]any) (bool, error)
type RowImportFunc func(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error
type DeleteFunc func(ctx context.Context, tx *sql.Tx, table string) error
type ImportProgress struct {
@ -72,15 +78,58 @@ type Manifest struct {
}
type TableManifest struct {
Name string `json:"name"`
File string `json:"file,omitempty"`
Files []string `json:"files"`
Columns []string `json:"columns"`
Rows int `json:"rows"`
Name string `json:"name"`
File string `json:"file,omitempty"`
Files []string `json:"files"`
FileManifests []FileManifest `json:"file_manifests,omitempty"`
Columns []string `json:"columns"`
Rows int `json:"rows"`
}
type FileManifest struct {
Path string `json:"path"`
Rows int `json:"rows"`
Size int64 `json:"size,omitempty"`
SHA256 string `json:"sha256,omitempty"`
}
var ErrNoManifest = errors.New("pack manifest not found")
type TableImportMode string
const (
TableImportSkip TableImportMode = "skip"
TableImportReplace TableImportMode = "replace"
TableImportFiles TableImportMode = "files"
)
type ImportPlan struct {
Full bool
Reason string
Tables []TableImportPlan
}
type TableImportPlan struct {
Table TableManifest
Mode TableImportMode
Files []FileManifest
Reason string
}
type IncrementalImportOptions struct {
DB *sql.DB
RootDir string
Previous Manifest
Current Manifest
Plan ImportPlan
DeleteTable DeleteFunc
Filter RowFilter
ImportRow RowImportFunc
Progress func(ImportProgress)
BeforeImport func(context.Context, *sql.Tx) error
AfterImport func(context.Context, *sql.Tx) error
}
func Export(ctx context.Context, opts ExportOptions) (Manifest, error) {
if opts.DB == nil {
return Manifest{}, errors.New("db is required")
@ -170,7 +219,7 @@ func Import(ctx context.Context, opts ImportOptions) (Manifest, error) {
}
}
for _, table := range manifest.Tables {
rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.Progress)
rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.ImportRow, opts.Progress)
if err != nil {
return Manifest{}, err
}
@ -188,6 +237,130 @@ func Import(ctx context.Context, opts ImportOptions) (Manifest, error) {
return manifest, nil
}
func PlanIncrementalImport(previous, current Manifest) ImportPlan {
if current.Version != previous.Version {
return ImportPlan{Full: true, Reason: "manifest version changed"}
}
previousTables := make(map[string]TableManifest, len(previous.Tables))
for _, table := range previous.Tables {
previousTables[table.Name] = table
}
currentTables := make(map[string]TableManifest, len(current.Tables))
for _, table := range current.Tables {
currentTables[table.Name] = table
}
for name := range previousTables {
if _, ok := currentTables[name]; !ok {
return ImportPlan{Full: true, Reason: "table removed: " + name}
}
}
plan := ImportPlan{}
for _, table := range current.Tables {
previousTable, ok := previousTables[table.Name]
if !ok {
plan.Tables = append(plan.Tables, TableImportPlan{
Table: table,
Mode: TableImportReplace,
Files: tableFileManifests(table),
Reason: "new table",
})
continue
}
tablePlan := planTableIncrement(previousTable, table)
plan.Tables = append(plan.Tables, tablePlan)
}
return plan
}
func (p ImportPlan) Changed() bool {
if p.Full {
return true
}
for _, table := range p.Tables {
if table.Mode != TableImportSkip {
return true
}
}
return false
}
func ImportIncremental(ctx context.Context, opts IncrementalImportOptions) (Manifest, ImportPlan, error) {
if opts.DB == nil {
return Manifest{}, ImportPlan{}, errors.New("db is required")
}
current := opts.Current
var err error
if len(current.Tables) == 0 {
current, err = ReadManifest(opts.RootDir)
if err != nil {
return Manifest{}, ImportPlan{}, err
}
}
plan := opts.Plan
if len(plan.Tables) == 0 && !plan.Full && plan.Reason == "" {
plan = PlanIncrementalImport(opts.Previous, current)
}
if plan.Full {
return Manifest{}, plan, errors.New("incremental import requires a non-full plan: " + plan.Reason)
}
if !plan.Changed() {
return current, plan, nil
}
tx, err := opts.DB.BeginTx(ctx, nil)
if err != nil {
return Manifest{}, plan, fmt.Errorf("begin incremental import tx: %w", err)
}
committed := false
defer func() {
if !committed {
_ = tx.Rollback()
}
}()
if opts.BeforeImport != nil {
if err := opts.BeforeImport(ctx, tx); err != nil {
return Manifest{}, plan, err
}
}
for _, tablePlan := range plan.Tables {
switch tablePlan.Mode {
case TableImportSkip:
continue
case TableImportReplace:
if err := deleteImportTable(ctx, tx, tablePlan.Table.Name, opts.DeleteTable); err != nil {
return Manifest{}, plan, err
}
rows, err := importTable(ctx, tx, opts.RootDir, tablePlan.Table, opts.Filter, opts.ImportRow, opts.Progress)
if err != nil {
return Manifest{}, plan, err
}
reportImportProgress(opts.Progress, ImportProgress{Phase: "table_done", Table: tablePlan.Table.Name, Rows: rows, TotalRows: tablePlan.Table.Rows})
case TableImportFiles:
table := tablePlan.Table
table.File = ""
table.Files = fileManifestPaths(tablePlan.Files)
table.FileManifests = tablePlan.Files
table.Rows = fileManifestRows(tablePlan.Files)
rows, err := importTable(ctx, tx, opts.RootDir, table, opts.Filter, opts.ImportRow, opts.Progress)
if err != nil {
return Manifest{}, plan, err
}
reportImportProgress(opts.Progress, ImportProgress{Phase: "table_done", Table: tablePlan.Table.Name, Rows: rows, TotalRows: table.Rows})
default:
return Manifest{}, plan, fmt.Errorf("unknown table import mode %q for %s", tablePlan.Mode, tablePlan.Table.Name)
}
}
if opts.AfterImport != nil {
if err := opts.AfterImport(ctx, tx); err != nil {
return Manifest{}, plan, err
}
}
if err := tx.Commit(); err != nil {
return Manifest{}, plan, fmt.Errorf("commit incremental import tx: %w", err)
}
committed = true
return current, plan, nil
}
func ReadManifest(rootDir string) (Manifest, error) {
data, err := os.ReadFile(filepath.Join(rootDir, ManifestName))
if errors.Is(err, os.ErrNotExist) {
@ -278,10 +451,10 @@ func exportTable(ctx context.Context, db *sql.DB, rootDir, table string, maxShar
if err := writer.close(); err != nil {
return TableManifest{}, err
}
return TableManifest{Name: table, Files: writer.files, Columns: cols, Rows: count}, nil
return TableManifest{Name: table, Files: writer.files, FileManifests: writer.fileManifests, Columns: cols, Rows: count}, nil
}
func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter, progress func(ImportProgress)) (int, error) {
func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableManifest, filter RowFilter, importRow RowImportFunc, progress func(ImportProgress)) (int, error) {
files := table.Files
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
files = []string{table.File}
@ -299,7 +472,7 @@ func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableMan
}
fileProgress := ImportProgress{Phase: "file_start", Table: table.Name, File: rel, FileIndex: index + 1, FileCount: len(files), TotalRows: table.Rows}
reportImportProgress(progress, fileProgress)
rows, err := importJSONLGzip(ctx, tx, file, table.Name, filter)
rows, err := importJSONLGzip(ctx, tx, file, table.Name, filter, importRow)
if err != nil {
_ = file.Close()
return totalRows, err
@ -315,7 +488,7 @@ func importTable(ctx context.Context, tx *sql.Tx, rootDir string, table TableMan
return totalRows, nil
}
func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter) (int, error) {
func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table string, filter RowFilter, importRow RowImportFunc) (int, error) {
gz, err := gzip.NewReader(reader)
if err != nil {
return 0, fmt.Errorf("open gzip for %s: %w", table, err)
@ -341,7 +514,11 @@ func importJSONLGzip(ctx context.Context, tx *sql.Tx, reader io.Reader, table st
continue
}
}
if err := insertRow(ctx, tx, table, row); err != nil {
importFunc := importRow
if importFunc == nil {
importFunc = insertRow
}
if err := importFunc(ctx, tx, table, row); err != nil {
return rows, err
}
rows++
@ -358,6 +535,16 @@ func reportImportProgress(progress func(ImportProgress), event ImportProgress) {
}
}
func deleteImportTable(ctx context.Context, tx *sql.Tx, table string, deleteTable DeleteFunc) error {
if deleteTable != nil {
return deleteTable(ctx, tx, table)
}
if _, err := tx.ExecContext(ctx, "delete from "+store.QuoteIdent(table)); err != nil {
return fmt.Errorf("clear table %s: %w", table, err)
}
return nil
}
func insertRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error {
cols := make([]string, 0, len(row))
for col := range row {
@ -391,8 +578,11 @@ type shardWriter struct {
nextShard int
rowsInShard int
files []string
fileManifests []FileManifest
currentRel string
file *os.File
counter *countingWriter
hasher hash.Hash
gz *gzip.Writer
}
@ -415,8 +605,10 @@ func (w *shardWriter) open() error {
w.nextShard++
w.rowsInShard = 0
w.files = append(w.files, rel)
w.currentRel = rel
w.file = file
w.counter = &countingWriter{w: file}
w.hasher = sha256.New()
w.counter = &countingWriter{w: io.MultiWriter(file, w.hasher)}
w.gz = gzip.NewWriter(w.counter)
return nil
}
@ -459,6 +651,17 @@ func (w *shardWriter) close() error {
if closeErr != nil {
return fmt.Errorf("close shard: %w", closeErr)
}
if w.currentRel != "" && w.counter != nil && w.hasher != nil {
w.fileManifests = append(w.fileManifests, FileManifest{
Path: w.currentRel,
Rows: w.rowsInShard,
Size: w.counter.n,
SHA256: hex.EncodeToString(w.hasher.Sum(nil)),
})
}
w.currentRel = ""
w.counter = nil
w.hasher = nil
return nil
}
@ -481,3 +684,119 @@ func exportValue(value any) any {
return v
}
}
func planTableIncrement(previous, current TableManifest) TableImportPlan {
if !sameStrings(previous.Columns, current.Columns) {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: tableFileManifests(current), Reason: "columns changed"}
}
previousFiles := tableFileManifests(previous)
currentFiles := tableFileManifests(current)
if len(previousFiles) == 0 && len(currentFiles) == 0 {
return TableImportPlan{Table: current, Mode: TableImportSkip, Reason: "unchanged"}
}
if !allFilesHaveFingerprints(previousFiles) || !allFilesHaveFingerprints(currentFiles) {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "missing file fingerprints"}
}
if sameFileManifests(previousFiles, currentFiles) {
return TableImportPlan{Table: current, Mode: TableImportSkip, Reason: "unchanged"}
}
if len(currentFiles) < len(previousFiles) {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "files removed"}
}
for i := 0; i < len(previousFiles)-1; i++ {
if !sameFileManifest(previousFiles[i], currentFiles[i]) {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "non-tail file changed"}
}
}
changed := make([]FileManifest, 0, len(currentFiles)-len(previousFiles)+1)
if len(previousFiles) > 0 {
oldTail := previousFiles[len(previousFiles)-1]
newTail := currentFiles[len(previousFiles)-1]
if oldTail.Path != newTail.Path {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "tail path changed"}
}
if !sameFileManifest(oldTail, newTail) {
if newTail.Rows < oldTail.Rows {
return TableImportPlan{Table: current, Mode: TableImportReplace, Files: currentFiles, Reason: "tail rows removed"}
}
changed = append(changed, newTail)
}
}
for i := len(previousFiles); i < len(currentFiles); i++ {
changed = append(changed, currentFiles[i])
}
if len(changed) == 0 {
return TableImportPlan{Table: current, Mode: TableImportSkip, Reason: "unchanged"}
}
return TableImportPlan{Table: current, Mode: TableImportFiles, Files: changed, Reason: "tail files changed"}
}
func tableFileManifests(table TableManifest) []FileManifest {
if len(table.FileManifests) > 0 {
out := make([]FileManifest, len(table.FileManifests))
copy(out, table.FileManifests)
return out
}
files := table.Files
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
files = []string{table.File}
}
out := make([]FileManifest, 0, len(files))
for _, file := range files {
out = append(out, FileManifest{Path: file})
}
return out
}
func allFilesHaveFingerprints(files []FileManifest) bool {
for _, file := range files {
if file.Path == "" || file.SHA256 == "" {
return false
}
}
return true
}
func sameFileManifests(a, b []FileManifest) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !sameFileManifest(a[i], b[i]) {
return false
}
}
return true
}
func sameFileManifest(a, b FileManifest) bool {
return a.Path == b.Path && a.Rows == b.Rows && a.Size == b.Size && a.SHA256 == b.SHA256
}
func fileManifestPaths(files []FileManifest) []string {
paths := make([]string, 0, len(files))
for _, file := range files {
paths = append(paths, file.Path)
}
return paths
}
func fileManifestRows(files []FileManifest) int {
rows := 0
for _, file := range files {
rows += file.Rows
}
return rows
}
func sameStrings(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

View File

@ -10,7 +10,7 @@ import (
"testing"
"time"
"github.com/vincentkoc/crawlkit/store"
"github.com/openclaw/crawlkit/store"
)
func TestExportImportTablesWithFilter(t *testing.T) {
@ -95,6 +95,158 @@ func TestExportRotatesShards(t *testing.T) {
if len(manifest.Tables[0].Files) < 2 {
t.Fatalf("expected multiple shards, got %+v", manifest.Tables[0].Files)
}
if len(manifest.Tables[0].FileManifests) != len(manifest.Tables[0].Files) {
t.Fatalf("file manifests = %+v, files = %+v", manifest.Tables[0].FileManifests, manifest.Tables[0].Files)
}
for _, file := range manifest.Tables[0].FileManifests {
if file.Path == "" || file.Rows == 0 || file.Size == 0 || len(file.SHA256) != 64 {
t.Fatalf("bad file manifest = %+v", file)
}
}
}
func TestPlanIncrementalImportDetectsTailFiles(t *testing.T) {
previous := Manifest{
Version: 1,
Tables: []TableManifest{{
Name: "things",
Columns: []string{"id", "body"},
Rows: 2,
Files: []string{"tables/things/000000.jsonl.gz"},
FileManifests: []FileManifest{{
Path: "tables/things/000000.jsonl.gz",
Rows: 2,
Size: 100,
SHA256: "old",
}},
}},
}
current := Manifest{
Version: 1,
Tables: []TableManifest{{
Name: "things",
Columns: []string{"id", "body"},
Rows: 3,
Files: []string{"tables/things/000000.jsonl.gz"},
FileManifests: []FileManifest{{
Path: "tables/things/000000.jsonl.gz",
Rows: 3,
Size: 120,
SHA256: "new",
}},
}},
}
plan := PlanIncrementalImport(previous, current)
if plan.Full || len(plan.Tables) != 1 {
t.Fatalf("plan = %+v", plan)
}
table := plan.Tables[0]
if table.Mode != TableImportFiles || len(table.Files) != 1 || table.Files[0].SHA256 != "new" {
t.Fatalf("table plan = %+v", table)
}
}
func TestPlanIncrementalImportReplacesUnsafeChanges(t *testing.T) {
previous := Manifest{
Version: 1,
Tables: []TableManifest{{
Name: "things",
Columns: []string{"id", "body"},
Rows: 2,
Files: []string{"tables/things/000000.jsonl.gz"},
FileManifests: []FileManifest{{
Path: "tables/things/000000.jsonl.gz",
Rows: 2,
Size: 100,
SHA256: "old",
}},
}},
}
current := Manifest{
Version: 1,
Tables: []TableManifest{{
Name: "things",
Columns: []string{"id", "body"},
Rows: 1,
Files: []string{"tables/things/000000.jsonl.gz"},
FileManifests: []FileManifest{{
Path: "tables/things/000000.jsonl.gz",
Rows: 1,
Size: 100,
SHA256: "new",
}},
}},
}
plan := PlanIncrementalImport(previous, current)
if plan.Full || len(plan.Tables) != 1 || plan.Tables[0].Mode != TableImportReplace {
t.Fatalf("plan = %+v", plan)
}
}
func TestImportIncrementalImportsOnlyPlannedFiles(t *testing.T) {
ctx := context.Background()
src, err := store.Open(ctx, store.Options{
Path: filepath.Join(t.TempDir(), "src.db"),
Schema: `create table things(id text primary key, body text not null);`,
})
if err != nil {
t.Fatal(err)
}
defer src.Close()
mustExec(t, src.DB(), `insert into things(id, body) values('one', 'same')`)
mustExec(t, src.DB(), `insert into things(id, body) values('two', 'old')`)
root := t.TempDir()
previous, err := Export(ctx, ExportOptions{
DB: src.DB(),
RootDir: root,
Tables: []string{"things"},
})
if err != nil {
t.Fatal(err)
}
dst, err := store.Open(ctx, store.Options{
Path: filepath.Join(t.TempDir(), "dst.db"),
Schema: `create table things(id text primary key, body text not null);`,
})
if err != nil {
t.Fatal(err)
}
defer dst.Close()
if _, err := Import(ctx, ImportOptions{DB: dst.DB(), RootDir: root}); err != nil {
t.Fatal(err)
}
mustExec(t, dst.DB(), `insert into things(id, body) values('local', 'keep')`)
mustExec(t, src.DB(), `update things set body = 'new' where id = 'two'`)
mustExec(t, src.DB(), `insert into things(id, body) values('three', 'added')`)
current, err := Export(ctx, ExportOptions{
DB: src.DB(),
RootDir: root,
Tables: []string{"things"},
})
if err != nil {
t.Fatal(err)
}
_, plan, err := ImportIncremental(ctx, IncrementalImportOptions{
DB: dst.DB(),
RootDir: root,
Previous: previous,
Current: current,
})
if err != nil {
t.Fatal(err)
}
if len(plan.Tables) != 1 || plan.Tables[0].Mode != TableImportFiles {
t.Fatalf("plan = %+v", plan)
}
var got string
if err := dst.DB().QueryRowContext(ctx, `select group_concat(id || ':' || body, ',') from (select id, body from things order by id)`).Scan(&got); err != nil {
t.Fatal(err)
}
if got != "local:keep,one:same,three:added,two:new" {
t.Fatalf("things = %q", got)
}
}
func TestImportHooks(t *testing.T) {

191
state/adapters.go Normal file
View File

@ -0,0 +1,191 @@
package state
import (
"context"
"database/sql"
"fmt"
"time"
)
const ScopedSchema = `
create table if not exists sync_state (
scope text primary key,
cursor text not null,
updated_at text not null
);
create index if not exists idx_sync_state_updated_at on sync_state(updated_at desc);
`
const CursorSchema = `
create table if not exists sync_state (
source text not null,
entity_type text not null,
entity_id text not null,
cursor text not null,
synced_at text not null,
primary key (source, entity_type, entity_id)
);
create index if not exists idx_sync_state_synced_at on sync_state(synced_at desc);
`
type ScopedStore struct {
db execQuerier
now func() time.Time
}
type ScopedRecord struct {
Scope string `json:"scope"`
Cursor string `json:"cursor"`
UpdatedAt time.Time `json:"updated_at"`
}
type CursorStore struct {
db execQuerier
now func() time.Time
}
type CursorRecord struct {
Source string `json:"source"`
EntityType string `json:"entity_type"`
EntityID string `json:"entity_id"`
Cursor string `json:"cursor"`
SyncedAt time.Time `json:"synced_at"`
}
func EnsureScopedSchema(ctx context.Context, db execQuerier) error {
if _, err := db.ExecContext(ctx, ScopedSchema); err != nil {
return fmt.Errorf("ensure scoped sync_state schema: %w", err)
}
return nil
}
func EnsureCursorSchema(ctx context.Context, db execQuerier) error {
if _, err := db.ExecContext(ctx, CursorSchema); err != nil {
return fmt.Errorf("ensure cursor sync_state schema: %w", err)
}
return nil
}
func NewScoped(db execQuerier) *ScopedStore {
return NewScopedWithClock(db, nil)
}
func NewScopedWithClock(db execQuerier, now func() time.Time) *ScopedStore {
if now == nil {
now = func() time.Time { return time.Now().UTC() }
}
return &ScopedStore{db: db, now: now}
}
func (s *ScopedStore) Set(ctx context.Context, scope, cursor string) error {
updatedAt := s.now().UTC()
_, err := s.db.ExecContext(ctx, `
insert into sync_state(scope, cursor, updated_at)
values (?, ?, ?)
on conflict(scope) do update set
cursor = excluded.cursor,
updated_at = excluded.updated_at
`, scope, cursor, updatedAt.Format(time.RFC3339Nano))
if err != nil {
return fmt.Errorf("set scoped sync state: %w", err)
}
return nil
}
func (s *ScopedStore) Get(ctx context.Context, scope string) (ScopedRecord, bool, error) {
var rec ScopedRecord
var updatedAt string
err := s.db.QueryRowContext(ctx, `
select scope, cursor, updated_at
from sync_state
where scope = ?
`, scope).Scan(&rec.Scope, &rec.Cursor, &updatedAt)
if err == sql.ErrNoRows {
return ScopedRecord{}, false, nil
}
if err != nil {
return ScopedRecord{}, false, err
}
parsed, err := time.Parse(time.RFC3339Nano, updatedAt)
if err != nil {
return ScopedRecord{}, false, fmt.Errorf("parse scoped sync state updated_at: %w", err)
}
rec.UpdatedAt = parsed
return rec, true, nil
}
func (s *ScopedStore) IsStale(ctx context.Context, scope string, maxAge time.Duration) (bool, error) {
rec, ok, err := s.Get(ctx, scope)
if err != nil {
return false, err
}
if !ok {
return true, nil
}
if maxAge <= 0 {
return false, nil
}
return s.now().UTC().Sub(rec.UpdatedAt) > maxAge, nil
}
func NewCursor(db execQuerier) *CursorStore {
return NewCursorWithClock(db, nil)
}
func NewCursorWithClock(db execQuerier, now func() time.Time) *CursorStore {
if now == nil {
now = func() time.Time { return time.Now().UTC() }
}
return &CursorStore{db: db, now: now}
}
func (s *CursorStore) Set(ctx context.Context, source, entityType, entityID, cursor string) error {
syncedAt := s.now().UTC()
_, err := s.db.ExecContext(ctx, `
insert into sync_state(source, entity_type, entity_id, cursor, synced_at)
values (?, ?, ?, ?, ?)
on conflict(source, entity_type, entity_id) do update set
cursor = excluded.cursor,
synced_at = excluded.synced_at
`, source, entityType, entityID, cursor, syncedAt.Format(time.RFC3339Nano))
if err != nil {
return fmt.Errorf("set cursor sync state: %w", err)
}
return nil
}
func (s *CursorStore) Get(ctx context.Context, source, entityType, entityID string) (CursorRecord, bool, error) {
var rec CursorRecord
var syncedAt string
err := s.db.QueryRowContext(ctx, `
select source, entity_type, entity_id, cursor, synced_at
from sync_state
where source = ? and entity_type = ? and entity_id = ?
`, source, entityType, entityID).Scan(&rec.Source, &rec.EntityType, &rec.EntityID, &rec.Cursor, &syncedAt)
if err == sql.ErrNoRows {
return CursorRecord{}, false, nil
}
if err != nil {
return CursorRecord{}, false, err
}
parsed, err := time.Parse(time.RFC3339Nano, syncedAt)
if err != nil {
return CursorRecord{}, false, fmt.Errorf("parse cursor sync state synced_at: %w", err)
}
rec.SyncedAt = parsed
return rec, true, nil
}
func (s *CursorStore) IsStale(ctx context.Context, source, entityType, entityID string, maxAge time.Duration) (bool, error) {
rec, ok, err := s.Get(ctx, source, entityType, entityID)
if err != nil {
return false, err
}
if !ok {
return true, nil
}
if maxAge <= 0 {
return false, nil
}
return s.now().UTC().Sub(rec.SyncedAt) > maxAge, nil
}

View File

@ -48,3 +48,81 @@ func TestSetGetAndStale(t *testing.T) {
t.Fatal("old record reported fresh")
}
}
func TestScopedStoreSetGetAndStale(t *testing.T) {
ctx := context.Background()
db, err := sql.Open("sqlite", "file:"+filepath.Join(t.TempDir(), "scoped.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
if err := EnsureScopedSchema(ctx, db); err != nil {
t.Fatal(err)
}
now := time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC)
store := NewScopedWithClock(db, func() time.Time { return now })
if err := store.Set(ctx, "share:last_import_at", "2026-05-01T12:00:00Z"); err != nil {
t.Fatal(err)
}
rec, ok, err := store.Get(ctx, "share:last_import_at")
if err != nil {
t.Fatal(err)
}
if !ok || rec.Cursor == "" {
t.Fatalf("record not found: %+v", rec)
}
stale, err := store.IsStale(ctx, "share:last_import_at", time.Hour)
if err != nil {
t.Fatal(err)
}
if stale {
t.Fatal("fresh scoped record reported stale")
}
store.now = func() time.Time { return now.Add(2 * time.Hour) }
stale, err = store.IsStale(ctx, "share:last_import_at", time.Hour)
if err != nil {
t.Fatal(err)
}
if !stale {
t.Fatal("old scoped record reported fresh")
}
}
func TestCursorStoreSetGetAndStale(t *testing.T) {
ctx := context.Background()
db, err := sql.Open("sqlite", "file:"+filepath.Join(t.TempDir(), "cursor.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
if err := EnsureCursorSchema(ctx, db); err != nil {
t.Fatal(err)
}
now := time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC)
store := NewCursorWithClock(db, func() time.Time { return now })
if err := store.Set(ctx, "share", "manifest", "generated_at", "2026-05-01T12:00:00Z"); err != nil {
t.Fatal(err)
}
rec, ok, err := store.Get(ctx, "share", "manifest", "generated_at")
if err != nil {
t.Fatal(err)
}
if !ok || rec.Cursor == "" {
t.Fatalf("record not found: %+v", rec)
}
stale, err := store.IsStale(ctx, "share", "manifest", "generated_at", time.Hour)
if err != nil {
t.Fatal(err)
}
if stale {
t.Fatal("fresh cursor record reported stale")
}
store.now = func() time.Time { return now.Add(2 * time.Hour) }
stale, err = store.IsStale(ctx, "share", "manifest", "generated_at", time.Hour)
if err != nil {
t.Fatal(err)
}
if !stale {
t.Fatal("old cursor record reported fresh")
}
}

142
vector/vector.go Normal file
View File

@ -0,0 +1,142 @@
package vector
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math"
"sort"
)
const DefaultRRFK = 60.0
type Scored[T any] struct {
Item T
Score float64
}
type RRFEntry[T any] struct {
Item T
Score float64
}
func EncodeFloat32(values []float32) ([]byte, error) {
buf := bytes.NewBuffer(make([]byte, 0, len(values)*4))
for _, value := range values {
if err := binary.Write(buf, binary.LittleEndian, value); err != nil {
return nil, fmt.Errorf("encode float32 vector: %w", err)
}
}
return buf.Bytes(), nil
}
func DecodeFloat32(blob []byte) ([]float32, error) {
if len(blob)%4 != 0 {
return nil, fmt.Errorf("float32 vector blob length %d is not a multiple of 4", len(blob))
}
out := make([]float32, len(blob)/4)
reader := bytes.NewReader(blob)
for i := range out {
if err := binary.Read(reader, binary.LittleEndian, &out[i]); err != nil {
return nil, fmt.Errorf("decode float32 vector: %w", err)
}
}
return out, nil
}
func ValidateDimensions(values []float32, dimensions int) error {
if dimensions <= 0 {
return errors.New("dimensions must be positive")
}
if len(values) != dimensions {
return fmt.Errorf("dimensions mismatch: got %d want %d", len(values), dimensions)
}
return nil
}
func Norm(values []float32) float64 {
var sum float64
for _, value := range values {
sum += float64(value) * float64(value)
}
return math.Sqrt(sum)
}
func CosineSimilarity(query []float32, queryNorm float64, candidate []float32) (float64, error) {
if len(candidate) != len(query) {
return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(candidate), len(query))
}
if queryNorm == 0 {
return 0, errors.New("query vector is zero")
}
candidateNorm := Norm(candidate)
if candidateNorm == 0 {
return 0, errors.New("candidate vector is zero")
}
var dot float64
for i := range query {
dot += float64(query[i]) * float64(candidate[i])
}
return dot / (queryNorm * candidateNorm), nil
}
func TopK[T any](items []Scored[T], limit int, tieLess func(left, right T) bool) []Scored[T] {
if limit <= 0 || len(items) == 0 {
return nil
}
sorted := append([]Scored[T](nil), items...)
sort.SliceStable(sorted, func(i, j int) bool {
if sorted[i].Score != sorted[j].Score {
return sorted[i].Score > sorted[j].Score
}
if tieLess == nil {
return false
}
return tieLess(sorted[i].Item, sorted[j].Item)
})
if len(sorted) > limit {
sorted = sorted[:limit]
}
return sorted
}
func ReciprocalRankFusion[T any](rankings [][]T, ids []func(T) string, weights []float64, k float64) []RRFEntry[T] {
if k <= 0 {
k = DefaultRRFK
}
entries := map[string]*RRFEntry[T]{}
for rankingIndex, ranking := range rankings {
weight := 1.0
if rankingIndex < len(weights) && weights[rankingIndex] != 0 {
weight = weights[rankingIndex]
}
var idFn func(T) string
if rankingIndex < len(ids) {
idFn = ids[rankingIndex]
}
for index, item := range ranking {
if idFn == nil {
continue
}
id := idFn(item)
if id == "" {
continue
}
entry := entries[id]
if entry == nil {
entry = &RRFEntry[T]{Item: item}
entries[id] = entry
}
entry.Score += weight / (k + float64(index+1))
}
}
out := make([]RRFEntry[T], 0, len(entries))
for _, entry := range entries {
out = append(out, *entry)
}
sort.SliceStable(out, func(i, j int) bool {
return out[i].Score > out[j].Score
})
return out
}

137
vector/vector_test.go Normal file
View File

@ -0,0 +1,137 @@
package vector
import (
"math"
"reflect"
"strings"
"testing"
)
func TestFloat32EncodingRoundTrip(t *testing.T) {
blob, err := EncodeFloat32([]float32{1, -2.5, 3.25})
require.NoError(t, err)
require.Len(t, blob, 12)
values, err := DecodeFloat32(blob)
require.NoError(t, err)
require.Equal(t, []float32{1, -2.5, 3.25}, values)
_, err = DecodeFloat32([]byte{1, 2, 3})
require.ErrorContains(t, err, "not a multiple of 4")
}
func TestCosineSimilarityAndDimensions(t *testing.T) {
require.NoError(t, ValidateDimensions([]float32{1, 2}, 2))
require.ErrorContains(t, ValidateDimensions([]float32{1}, 2), "dimensions mismatch")
require.ErrorContains(t, ValidateDimensions([]float32{1}, 0), "positive")
query := []float32{1, 0}
score, err := CosineSimilarity(query, Norm(query), []float32{0.5, 0})
require.NoError(t, err)
require.InDelta(t, 1, score, 0.0001)
_, err = CosineSimilarity(query, 0, []float32{1, 0})
require.ErrorContains(t, err, "query vector is zero")
_, err = CosineSimilarity(query, Norm(query), []float32{0, 0})
require.ErrorContains(t, err, "candidate vector is zero")
_, err = CosineSimilarity(query, Norm(query), []float32{1})
require.ErrorContains(t, err, "dimensions mismatch")
require.Equal(t, math.Sqrt(5), Norm([]float32{1, 2}))
}
func TestTopK(t *testing.T) {
items := []Scored[string]{
{Item: "c", Score: 0.3},
{Item: "a", Score: 0.5},
{Item: "b", Score: 0.5},
}
top := TopK(items, 2, func(left, right string) bool { return left < right })
require.Equal(t, []Scored[string]{{Item: "a", Score: 0.5}, {Item: "b", Score: 0.5}}, top)
require.Nil(t, TopK(items, 0, nil))
}
func TestReciprocalRankFusion(t *testing.T) {
rankings := [][]string{
{"a", "b"},
{"b", "c"},
}
ids := []func(string) string{
func(value string) string { return value },
func(value string) string { return value },
}
results := ReciprocalRankFusion(rankings, ids, []float64{1, 1}, 60)
require.Len(t, results, 3)
require.Equal(t, "b", results[0].Item)
require.Greater(t, results[0].Score, results[1].Score)
}
type requireAPI struct{}
var require requireAPI
func (requireAPI) NoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func (requireAPI) Equal(t *testing.T, want, got any) {
t.Helper()
if !reflect.DeepEqual(want, got) {
t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got)
}
}
func (requireAPI) Len(t *testing.T, value any, want int) {
t.Helper()
got := reflect.ValueOf(value).Len()
if got != want {
t.Fatalf("len mismatch: got %d want %d", got, want)
}
}
func (requireAPI) Nil(t *testing.T, value any) {
t.Helper()
if !isNil(value) {
t.Fatalf("expected nil, got %#v", value)
}
}
func (requireAPI) Greater(t *testing.T, left, right float64) {
t.Helper()
if left <= right {
t.Fatalf("expected %v > %v", left, right)
}
}
func (requireAPI) InDelta(t *testing.T, want, got, delta float64) {
t.Helper()
diff := math.Abs(want - got)
if diff > delta {
t.Fatalf("not within delta: want %v got %v delta %v", want, got, delta)
}
}
func (requireAPI) ErrorContains(t *testing.T, err error, needle string) {
t.Helper()
if err == nil {
t.Fatalf("expected error containing %q, got nil", needle)
}
if !strings.Contains(err.Error(), needle) {
t.Fatalf("expected error containing %q, got %q", needle, err.Error())
}
}
func isNil(value any) bool {
if value == nil {
return true
}
reflected := reflect.ValueOf(value)
switch reflected.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
return reflected.IsNil()
default:
return false
}
}