Compare commits
18 Commits
ci/codeql-
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1def2c98f | ||
|
|
be98cde23f | ||
|
|
b52eefaa40 | ||
|
|
733714a5e7 | ||
|
|
40c787c54a | ||
|
|
40317aa538 | ||
|
|
fb969672e0 | ||
|
|
67c6f4655b | ||
|
|
335a95bd66 | ||
|
|
d8c8778f19 | ||
|
|
eeb10dcd30 | ||
|
|
016e849e3c | ||
|
|
89d35a67a4 | ||
|
|
0da02de393 | ||
|
|
14dd5478f4 | ||
|
|
4b4303556a | ||
|
|
abcb77e6fc | ||
|
|
f328cfba2f |
@ -1,36 +1,59 @@
|
||||
---
|
||||
name: discrawl
|
||||
description: Use for local Discord archive search, sync freshness, DMs, channel summaries, and Discrawl repo/release work.
|
||||
description: Use for local Discord archive search, sync freshness, DMs, channel summaries, desktop/API/git-share sources, TUI browsing, and Discrawl repo/release work.
|
||||
---
|
||||
|
||||
# Discrawl
|
||||
|
||||
Use local archive data first for Discord questions. Browse or hit live APIs only when the local archive is stale or the user asks for current external context.
|
||||
Use local Discord archive data first for Discord questions. Hit Discord APIs
|
||||
only when the archive is stale, missing the requested scope, or the user asks
|
||||
for current external context.
|
||||
|
||||
## Sources
|
||||
|
||||
- DB: `~/.discrawl/discrawl.db`
|
||||
- Config: `~/.discrawl/config.toml`
|
||||
- Repo: `~/Projects/discrawl`
|
||||
- Preferred CLI: `discrawl`; fallback to repo binary if installed binary is stale
|
||||
- Cache: `~/.discrawl/cache`
|
||||
- Logs: `~/.discrawl/logs`
|
||||
- Git share repo: `~/.discrawl/share`
|
||||
- Repo: `openclaw/discrawl`; use `~/GIT/_Perso/discrawl` only after verifying
|
||||
its remote targets `openclaw/discrawl`, otherwise use a fresh checkout
|
||||
- Preferred CLI: `discrawl`; fallback to `go run ./cmd/discrawl` from the repo if the installed binary is stale
|
||||
|
||||
## Freshness
|
||||
|
||||
For recent/current questions, check freshness before analysis:
|
||||
|
||||
```bash
|
||||
discrawl status --json
|
||||
```
|
||||
|
||||
For precise freshness from the default database:
|
||||
|
||||
```bash
|
||||
sqlite3 ~/.discrawl/discrawl.db \
|
||||
"select coalesce(max(updated_at),'') from sync_state where scope like 'channel:%';"
|
||||
```
|
||||
|
||||
Routine refresh:
|
||||
Routine diagnostics:
|
||||
|
||||
```bash
|
||||
discrawl doctor
|
||||
```
|
||||
|
||||
Desktop-local refresh:
|
||||
|
||||
```bash
|
||||
discrawl sync --source wiretap
|
||||
```
|
||||
|
||||
Bot API latest refresh, when credentials are available:
|
||||
|
||||
```bash
|
||||
discrawl sync
|
||||
```
|
||||
|
||||
Historical/backfill refresh:
|
||||
Use `--full` only for deliberate historical backfills:
|
||||
|
||||
```bash
|
||||
discrawl sync --full
|
||||
@ -42,7 +65,7 @@ If SQLite reports busy/locked, check for stray `discrawl` processes before retry
|
||||
|
||||
1. Resolve scope: guild, channel, DM, author, keyword, date range.
|
||||
2. Check freshness for recent/current requests.
|
||||
3. Use CLI for normal reads; use SQL for precise counts/rankings.
|
||||
3. Prefer CLI search/messages for slices; use read-only SQL for exact counts.
|
||||
4. Report absolute date spans, counts, channel/DM names, and known gaps.
|
||||
|
||||
Common commands:
|
||||
@ -50,26 +73,52 @@ Common commands:
|
||||
```bash
|
||||
discrawl search "query"
|
||||
discrawl messages --channel '#maintainers' --days 7 --all
|
||||
discrawl --json sql "select count(*) from messages;"
|
||||
discrawl dms --last 20
|
||||
discrawl tui --dm
|
||||
discrawl sql "select count(*) from messages;"
|
||||
```
|
||||
|
||||
When the installed CLI lacks a new feature, build or run from `~/Projects/discrawl` before concluding the feature is missing.
|
||||
## SQL
|
||||
|
||||
## Discord DMs
|
||||
Use `discrawl sql` for exact counts, joins, and ranking queries when normal
|
||||
CLI reads are too coarse. The command is read-only by default, accepts SQL as
|
||||
args or stdin, and supports `--json` for agent parsing.
|
||||
|
||||
Wiretap/Desktop cache DMs are local-only. Do not imply they are in the published Git snapshot. For missing recent DMs, refresh first; stale archive is a common cause.
|
||||
Useful examples:
|
||||
|
||||
```bash
|
||||
discrawl --json sql "select count(*) as messages from messages;"
|
||||
discrawl --json sql "select coalesce(nullif(c.name, ''), m.channel_id) as channel, count(*) as messages from messages m left join channels c on c.id = m.channel_id group by m.channel_id order by messages desc limit 20;"
|
||||
discrawl --json sql "select coalesce(nullif(mm.display_name, ''), nullif(mm.global_name, ''), nullif(mm.username, ''), m.author_id) as author, count(*) as messages from messages m left join members mm on mm.guild_id = m.guild_id and mm.user_id = m.author_id group by m.guild_id, m.author_id order by messages desc limit 20;"
|
||||
```
|
||||
|
||||
Never use `--unsafe --confirm` unless the user explicitly asks for a database
|
||||
mutation and the write has been reviewed.
|
||||
|
||||
When the installed CLI lacks a new feature, build or run from a verified
|
||||
`openclaw/discrawl` checkout before concluding the feature is missing.
|
||||
|
||||
## Discord Boundaries
|
||||
|
||||
Bot API sync requires configured Discord bot credentials; do not invent token
|
||||
availability. Desktop wiretap mode reads local Discord Desktop artifacts and
|
||||
must not extract credentials, use user tokens, call Discord as the user, or
|
||||
write to Discord application storage. Wiretap/Desktop cache DMs are local-only
|
||||
and must not be described as part of the published Git snapshot. Git-share
|
||||
snapshots must not include secrets or `@me` DM rows.
|
||||
|
||||
## Verification
|
||||
|
||||
For repo edits, prefer existing Go gates:
|
||||
|
||||
```bash
|
||||
go test ./...
|
||||
GOWORK=off go test ./...
|
||||
```
|
||||
|
||||
Then run targeted CLI smoke for the touched surface, for example:
|
||||
|
||||
```bash
|
||||
discrawl doctor
|
||||
discrawl status --json
|
||||
discrawl search "test" --limit 5
|
||||
```
|
||||
|
||||
12
.editorconfig
Normal file
12
.editorconfig
Normal file
@ -0,0 +1,12 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
indent_style = tab
|
||||
indent_size = 4
|
||||
|
||||
[*.{md,yml,yaml,json,toml}]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
6
.gitattributes
vendored
Normal file
6
.gitattributes
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
* text=auto
|
||||
*.go text eol=lf
|
||||
*.md text eol=lf
|
||||
*.toml text eol=lf
|
||||
*.yml text eol=lf
|
||||
*.yaml text eol=lf
|
||||
12
.github/CODEOWNERS
vendored
Normal file
12
.github/CODEOWNERS
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
# Protect ownership and automation rules.
|
||||
/.github/CODEOWNERS @openclaw/openclaw-secops
|
||||
/.github/dependabot.yml @openclaw/openclaw-secops
|
||||
/.github/workflows/ @openclaw/openclaw-secops
|
||||
|
||||
# Release, backup, and package integrity surfaces.
|
||||
/.goreleaser.yaml @openclaw/openclaw-secops
|
||||
/go.mod @openclaw/openclaw-secops
|
||||
/go.sum @openclaw/openclaw-secops
|
||||
/scripts/*backup* @openclaw/openclaw-secops
|
||||
/scripts/*release* @openclaw/openclaw-secops
|
||||
/scripts/*publish* @openclaw/openclaw-secops
|
||||
37
.github/workflows/codeql.yml
vendored
Normal file
37
.github/workflows/codeql.yml
vendored
Normal file
@ -0,0 +1,37 @@
|
||||
name: CodeQL
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
schedule:
|
||||
- cron: "29 4 * * 1"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: analyze
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: go
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v4
|
||||
59
.github/workflows/release.yml
vendored
59
.github/workflows/release.yml
vendored
@ -44,3 +44,62 @@ jobs:
|
||||
args: release --clean --config /tmp/.goreleaser.yaml
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
update-homebrew-tap:
|
||||
runs-on: ubuntu-latest
|
||||
needs: goreleaser
|
||||
steps:
|
||||
- name: Resolve release tag
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
echo "RELEASE_TAG=${{ inputs.tag }}" >> "$GITHUB_ENV"
|
||||
else
|
||||
echo "RELEASE_TAG=${{ github.ref_name }}" >> "$GITHUB_ENV"
|
||||
fi
|
||||
|
||||
- name: Dispatch tap formula update
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.HOMEBREW_TAP_TOKEN }}
|
||||
run: |
|
||||
if [ -z "$GH_TOKEN" ]; then
|
||||
echo "::error::Set HOMEBREW_TAP_TOKEN with workflow access to steipete/homebrew-tap"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
request_id="discrawl-${RELEASE_TAG}-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
|
||||
expected_title="Update discrawl for ${RELEASE_TAG} (${request_id})"
|
||||
|
||||
gh workflow run update-formula.yml \
|
||||
--repo steipete/homebrew-tap \
|
||||
--ref main \
|
||||
-f formula=discrawl \
|
||||
-f tag="$RELEASE_TAG" \
|
||||
-f repository=openclaw/discrawl \
|
||||
-f artifact_template="{formula}_{version}_{target}.tar.gz" \
|
||||
-f request_id="$request_id"
|
||||
|
||||
run_id=""
|
||||
for _ in {1..30}; do
|
||||
run_id=$(gh run list \
|
||||
--repo steipete/homebrew-tap \
|
||||
--workflow update-formula.yml \
|
||||
--branch main \
|
||||
--event workflow_dispatch \
|
||||
--limit 20 \
|
||||
--json databaseId,displayTitle \
|
||||
--jq ".[] | select(.displayTitle == \"$expected_title\") | .databaseId" | head -n1)
|
||||
if [ -n "$run_id" ]; then
|
||||
break
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
|
||||
if [ -z "$run_id" ]; then
|
||||
echo "::error::Could not find tap workflow run with title: $expected_title"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
gh run watch "$run_id" \
|
||||
--repo steipete/homebrew-tap \
|
||||
--exit-status \
|
||||
--interval 10
|
||||
|
||||
63
.github/workflows/secret-scan.yml
vendored
Normal file
63
.github/workflows/secret-scan.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
name: "Security Gate: Secret Scanning"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["**"]
|
||||
pull_request:
|
||||
branches: [main, master]
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
name: Scan for Verified Secrets
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Resolve scan range
|
||||
id: scan_range
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
PR_BASE_SHA: ${{ github.event.pull_request.base.sha }}
|
||||
PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
|
||||
PUSH_BASE_SHA: ${{ github.event.before }}
|
||||
PUSH_HEAD_SHA: ${{ github.sha }}
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
zero_sha="0000000000000000000000000000000000000000"
|
||||
|
||||
if [[ "$EVENT_NAME" == "pull_request" ]]; then
|
||||
base="$PR_BASE_SHA"
|
||||
head="$PR_HEAD_SHA"
|
||||
else
|
||||
base="$PUSH_BASE_SHA"
|
||||
head="$PUSH_HEAD_SHA"
|
||||
if [[ -z "$base" || "$base" == "$zero_sha" ]]; then
|
||||
base="origin/$DEFAULT_BRANCH"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "base=$base" >> "$GITHUB_OUTPUT"
|
||||
echo "head=$head" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: TruffleHog OSS
|
||||
id: trufflehog
|
||||
uses: trufflesecurity/trufflehog@v3.95.2
|
||||
with:
|
||||
path: ./
|
||||
base: ${{ steps.scan_range.outputs.base }}
|
||||
head: ${{ steps.scan_range.outputs.head }}
|
||||
extra_args: --only-verified --debug
|
||||
|
||||
- name: Notify on failure
|
||||
if: steps.trufflehog.outcome == 'failure'
|
||||
run: |
|
||||
echo "::error::Verified secrets found. Rotate the credential before merging."
|
||||
exit 1
|
||||
86
.github/workflows/stale.yml
vendored
Normal file
86
.github/workflows/stale.yml
vendored
Normal file
@ -0,0 +1,86 @@
|
||||
name: Stale
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "25 4 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Mark stale unassigned issues and pull requests
|
||||
uses: actions/stale@v10
|
||||
with:
|
||||
days-before-issue-stale: 14
|
||||
days-before-issue-close: 7
|
||||
days-before-pr-stale: 14
|
||||
days-before-pr-close: 7
|
||||
stale-issue-label: stale
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: enhancement,maintainer,pinned,security,no-stale
|
||||
exempt-pr-labels: maintainer,no-stale
|
||||
operations-per-run: 1000
|
||||
ascending: true
|
||||
exempt-all-assignees: true
|
||||
remove-stale-when-updated: true
|
||||
stale-issue-message: |
|
||||
This issue has been automatically marked as stale due to inactivity.
|
||||
Please add updated discrawl details or it will be closed.
|
||||
stale-pr-message: |
|
||||
This pull request has been automatically marked as stale due to inactivity.
|
||||
Please update it or it will be closed.
|
||||
close-issue-message: |
|
||||
Closing due to inactivity.
|
||||
If this still affects discrawl, open a new issue with current reproduction details.
|
||||
close-issue-reason: not_planned
|
||||
close-pr-message: |
|
||||
Closing due to inactivity.
|
||||
If this PR should be revived, reopen it with current context and validation.
|
||||
|
||||
- name: Mark stale assigned issues
|
||||
uses: actions/stale@v10
|
||||
with:
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 10
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
stale-issue-label: stale
|
||||
exempt-issue-labels: enhancement,maintainer,pinned,security,no-stale
|
||||
operations-per-run: 1000
|
||||
ascending: true
|
||||
include-only-assigned: true
|
||||
remove-stale-when-updated: true
|
||||
stale-issue-message: |
|
||||
This assigned issue has been automatically marked as stale after 30 days of inactivity.
|
||||
Please add an update or it will be closed.
|
||||
close-issue-message: |
|
||||
Closing due to inactivity.
|
||||
If this still affects discrawl, reopen or file a new issue with current evidence.
|
||||
close-issue-reason: not_planned
|
||||
|
||||
- name: Mark stale assigned pull requests
|
||||
uses: actions/stale@v10
|
||||
with:
|
||||
days-before-issue-stale: -1
|
||||
days-before-issue-close: -1
|
||||
days-before-pr-stale: 27
|
||||
days-before-pr-close: 7
|
||||
stale-pr-label: stale
|
||||
exempt-pr-labels: maintainer,no-stale
|
||||
operations-per-run: 1000
|
||||
ascending: true
|
||||
include-only-assigned: true
|
||||
ignore-pr-updates: true
|
||||
remove-stale-when-updated: true
|
||||
stale-pr-message: |
|
||||
This assigned pull request has been automatically marked as stale after being open for 27 days.
|
||||
Please add an update or it will be closed.
|
||||
close-pr-message: |
|
||||
Closing due to inactivity.
|
||||
If this PR should be revived, reopen it with current context and validation.
|
||||
54
CHANGELOG.md
54
CHANGELOG.md
@ -1,21 +1,41 @@
|
||||
# Changelog
|
||||
|
||||
## 0.7.0 - Unreleased
|
||||
## 0.7.0 - 2026-05-08
|
||||
|
||||
### Changes
|
||||
|
||||
- Document the crawlkit-backed config/status/control, snapshot, mirror,
|
||||
sync-state, output, and shared TUI surfaces now used on `main`.
|
||||
- Clarify that Discord bot sync, desktop wiretap parsing, DM privacy filters,
|
||||
schema ownership, FTS/ranking, embeddings, and analytics remain app-owned.
|
||||
- Align terminal browser docs with the gitcrawl-style shared TUI model:
|
||||
channel/person/thread groups, message rows, detail/thread panes, sorting,
|
||||
mouse selection, right-click actions, and local/remote status chrome.
|
||||
- Added `discrawl tui`, a terminal archive browser for stored guild messages and local `@me` wiretap DMs using the shared crawlkit pane browser.
|
||||
- Added crawlkit-backed `metadata --json`, `status --json`, and `doctor --json` control surfaces for launchers, automation, and CI checks.
|
||||
- Published the generated documentation site at `discrawl.sh`, including command pages, install/setup docs, configuration, security notes, guides, a contact page, and social cards.
|
||||
- Moved the Go module and release metadata to `github.com/openclaw/discrawl`.
|
||||
|
||||
### Fixes
|
||||
|
||||
- Kept documented command-local search flags working after the query, such as `discrawl search "term" --limit 5`. Thanks @PrinceOfEgypt.
|
||||
- Made the terminal browser more useful and accurate: default guild scoping, newest-message startup, compact panes, selected-message detail panes, count-header sorting, local/remote status labels, right-click actions, Discord message URLs, row labels, direct-message pane labels, mention rendering, inline mention resolution, attachment details, and reply-context hydration without broad thread scans.
|
||||
- Kept read-only commands such as `search`, `messages`, and safe `sql` usable while `tail` or another writer holds the sync lock. Thanks @PrinceOfEgypt.
|
||||
- Kept `tui --help`, status, and terminal-browser reads safe for fresh or missing local databases without triggering Git snapshot auto-update.
|
||||
- Kept local-only snapshot rows filtered during shared archive imports and forwarded snapshot import progress through the crawlkit import path.
|
||||
- Made stale Git snapshot imports plan shard deltas from crawlkit file fingerprints or Git object identity, so routine shared-archive refreshes import changed message tail shards instead of rebuilding every table and FTS index.
|
||||
- Included progress percentages in message-sync logs.
|
||||
- Fixed GoReleaser version stamping after the module path move.
|
||||
|
||||
### Documentation
|
||||
|
||||
- Documented the crawlkit-backed config/status/control, snapshot, mirror, sync-state, output, and shared TUI surfaces now used on `main`.
|
||||
- Clarified that Discord bot sync, desktop wiretap parsing, DM privacy filters, schema ownership, FTS/ranking, embeddings, and analytics remain app-owned.
|
||||
- Aligned terminal-browser docs with the gitcrawl-style shared TUI model: channel/person/thread groups, message rows, detail/thread panes, sorting, mouse selection, right-click actions, and local/remote status chrome.
|
||||
- Refreshed the repo-local `discrawl` agent skill for local Discord archive, freshness, query, boundary, TUI, verification, and read-only SQL workflows.
|
||||
|
||||
### Maintenance
|
||||
|
||||
- Document the read-only `metadata --json`, `status --json`, and
|
||||
`doctor --json` control surface for launchers, automation, and CI checks.
|
||||
- Migrated runtime paths, SQLite opening, archive mirror/export/import helpers, output/status wiring, and TUI plumbing onto the shared `crawlkit` infrastructure.
|
||||
- Moved reusable embedding providers and vector helpers onto `crawlkit` while keeping Discrawl-owned storage, FTS, queueing, and privacy filters local.
|
||||
- Updated crawlkit through `v0.4.1`, switched imports to `github.com/openclaw/crawlkit`, and added CI smoke coverage for the crawlkit control surface and merge behavior.
|
||||
- Added CodeQL, verified secret scanning, protected automation owners, stale issue automation, `.editorconfig`, and `.gitattributes`.
|
||||
- Added release workflow automation that dispatches the Homebrew tap formula update after GoReleaser publishes a tag.
|
||||
|
||||
## 0.6.6 - 2026-05-05
|
||||
|
||||
### Fixes
|
||||
|
||||
@ -41,24 +61,10 @@
|
||||
- Refreshed dependency and CI tooling pins, including GoReleaser, `go-toml`, golangci-lint, and gosec.
|
||||
- Tightened CI compatibility with the latest linters and made signal-cancellation and sync fixture tests deterministic under the race detector.
|
||||
|
||||
### Fixes
|
||||
|
||||
- Label direct-message TUI panes as direct messages instead of raw `@me` guild rows, keeping DM channel/person context readable.
|
||||
- Inherit shared crawlkit TUI improvements for newest-first startup, count-header sorting, selected-message-first chat detail panes, and gitcrawl-style metadata labels.
|
||||
- Surface Discord attachment filenames and extracted text in TUI detail panes instead of only showing `attachments=true`.
|
||||
|
||||
## 0.6.3 - 2026-05-01
|
||||
|
||||
### Changes
|
||||
|
||||
- Add crawlkit control metadata/status surfaces with `metadata --json`, `status --json`, and `doctor --json`.
|
||||
- Add `tap` and `cache-import` as public desktop-cache import names while keeping `wiretap` as a documented legacy alias.
|
||||
- Add `discrawl tui`, a terminal archive browser for stored guild messages and local `@me` wiretap DMs using the shared `crawlkit/tui` package.
|
||||
- Render TUI rows with compact panes and expose pinned, attachment, reply, channel, and author metadata in the detail pane.
|
||||
|
||||
### Fixes
|
||||
|
||||
- Keep status and TUI reads safe for fresh or missing local databases without triggering git-share auto-update.
|
||||
- Added OS keyring fallback for Discord bot-token resolution, keeping env as the first source and documenting the default keyring item. (#17)
|
||||
- Clarified and locked down FTS query normalization so operator-like search terms such as `AND`, `OR`, `NOT`, `NEAR`, and `*` stay parameterized and quoted before SQLite `MATCH`. Thanks @mvanhorn.
|
||||
|
||||
|
||||
@ -177,7 +177,9 @@ The terminal browser uses the shared crawlkit explorer. The left pane groups
|
||||
channels, people, or threads; the middle pane lists messages; the right pane
|
||||
shows the selected message, surrounding conversation, and thread detail. Mouse
|
||||
selection, right-click actions, sortable headers, and the local/remote footer
|
||||
follow the same interaction model as `gitcrawl tui`.
|
||||
follow the same interaction model as `gitcrawl tui`. See
|
||||
[`docs/commands/tui.md`](docs/commands/tui.md) for flags and read-only/DM scope
|
||||
notes.
|
||||
|
||||
### `init`
|
||||
|
||||
@ -247,6 +249,7 @@ When `--channels` includes a forum channel id, `discrawl` expands that forum's t
|
||||
Long runs now emit periodic progress logs to stderr so large backfills and Git snapshot imports do not look hung.
|
||||
If in-flight channels stop completing for a while, `discrawl` now emits `message sync waiting` heartbeat logs with the oldest active channel, per-channel page activity, and skip/defer counters, and every run ends with a `message sync finished` summary.
|
||||
Each channel crawl also has a bounded runtime budget, so a pathological channel is deferred and retried on the next sync instead of pinning a worker forever.
|
||||
Retryable failures and unavailable-channel markers are tracked per channel; stale unavailable markers are cleared after a later successful crawl, and marker cleanup is best-effort so one missing local sync-state row cannot crash the run.
|
||||
Full sync member refresh is best-effort and currently gives up after five minutes without a caller-supplied deadline, so message sync completion is not held hostage by a slow guild member crawl.
|
||||
When the archive is already complete, `sync --full` now reuses the stored backlog markers and limits steady-state refresh to live top-level channels plus active threads instead of revisiting every stored archived thread.
|
||||
If a guild already has a local member snapshot, routine syncs reuse it and skip another full member crawl until that snapshot ages out.
|
||||
@ -482,9 +485,9 @@ discrawl subscribe --stale-after 15m https://github.com/example/discord-archive.
|
||||
discrawl subscribe --no-auto-update https://github.com/example/discord-archive.git
|
||||
```
|
||||
|
||||
Once `share.remote` is configured, read commands auto-fetch and import when the local share import is older than `share.stale_after` (default `15m`). `discrawl update` forces the same pull/import step manually. `discrawl sync` does not auto-import the share unless `--update=auto` or `--update=force` is provided, so routine live refreshes stay fast.
|
||||
Once `share.remote` is configured, read commands auto-fetch and import when the local share import is older than `share.stale_after` (default `15m`). Imports are planned from crawlkit shard fingerprints, with a Git-object fallback for older manifests, so routine updates normally read only changed tail shards and preserve local FTS rows instead of rebuilding the whole archive. `discrawl update` forces the same pull/import step manually. `discrawl sync` does not auto-import the share unless `--update=auto` or `--update=force` is provided, so routine live refreshes stay fast.
|
||||
|
||||
Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync --update=auto` and `discrawl messages --sync` import the Git snapshot first, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass.
|
||||
Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync --update=auto` and `discrawl messages --sync` import the Git snapshot first, usually as a changed-shard delta, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass.
|
||||
|
||||
Git snapshots publish non-DM archive tables by default. Embedding queue state stays local to each machine, and Git-only readers can use FTS immediately without an embedding provider.
|
||||
|
||||
|
||||
@ -31,9 +31,9 @@ discrawl search "panic: nil pointer"
|
||||
discrawl tail
|
||||
```
|
||||
|
||||
`discrawl tui` uses the shared crawlkit terminal explorer: channel/person/thread
|
||||
groups on the left, message rows in the middle, and readable message/thread
|
||||
detail on the right.
|
||||
[`discrawl tui`](commands/tui.html) uses the shared crawlkit terminal explorer:
|
||||
channel/person/thread groups on the left, message rows in the middle, and
|
||||
readable message/thread detail on the right.
|
||||
|
||||
## Sections
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ By default, `sync` runs both live/local sources and does **not** import the Git
|
||||
- Discord bot-token sync for bot-visible guild data
|
||||
- local Discord Desktop cache import for classifiable cached messages and proven DMs
|
||||
|
||||
Use [`update`](update.html) when you want to pull/import the shared Git snapshot. If you intentionally want a sync run to import the snapshot before live deltas, pass `--update=auto` (only when stale) or `--update=force` (always). `--no-update` is accepted as an explicit no-op alias for the default.
|
||||
Use [`update`](update.html) when you want to pull/import the shared Git snapshot. Snapshot imports normally use changed-shard deltas, but unsafe table changes fall back to a full import. If you intentionally want a sync run to import the snapshot before live deltas, pass `--update=auto` (only when stale) or `--update=force` (always). `--no-update` is accepted as an explicit no-op alias for the default.
|
||||
|
||||
Run one explicit `--full` pass when you want a complete historical guild archive. Use plain `sync` afterward for frequent latest-message and desktop-cache refreshes.
|
||||
|
||||
@ -70,6 +70,8 @@ discrawl sync --with-embeddings
|
||||
- Heartbeat logs (`message sync waiting`) name the oldest active channel and per-channel page activity if in-flight channels stop completing for a while.
|
||||
- Every run ends with a `message sync finished` summary.
|
||||
- Each channel crawl has a bounded runtime budget; pathological channels are deferred and retried next sync.
|
||||
- Retryable failures and unavailable-channel markers are tracked per channel; stale unavailable markers are cleared after a later successful crawl.
|
||||
- Marker cleanup is best-effort, so one missing local sync-state row cannot crash the run.
|
||||
- Full sync member refresh is best-effort and gives up after five minutes without a caller-supplied deadline.
|
||||
- When the archive is already complete, `sync --full` reuses backlog markers and limits steady-state refresh to live top-level channels plus active threads.
|
||||
|
||||
|
||||
47
docs/commands/tui.md
Normal file
47
docs/commands/tui.md
Normal file
@ -0,0 +1,47 @@
|
||||
# `tui`
|
||||
|
||||
Opens the local terminal archive browser for stored messages.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
discrawl tui
|
||||
discrawl tui --guild 123456789012345678 --channel general
|
||||
discrawl tui --guilds 123,456 --author 1456464433768300635
|
||||
discrawl tui --dm
|
||||
discrawl --json tui --limit 50
|
||||
```
|
||||
|
||||
## What it shows
|
||||
|
||||
The browser uses the shared crawlkit explorer:
|
||||
|
||||
- left pane: channel, person, or thread groups
|
||||
- middle pane: newest matching message rows
|
||||
- right pane: selected message detail, attachments, replies, and thread context
|
||||
- footer: local DB or remote Git snapshot source
|
||||
|
||||
Mouse selection, right-click actions, sortable headers, refresh, and chat layout match the other crawlkit-backed archive tools.
|
||||
|
||||
## Flags
|
||||
|
||||
- `--guild <id>` / `--guilds <id,id>` - restrict the guild scope
|
||||
- `--dm` - browse local direct messages under the synthetic `@me` guild
|
||||
- `--channel <id|name|#name>` - restrict to one channel or DM conversation
|
||||
- `--author <id|name>` - restrict to one author
|
||||
- `--limit <n>` - newest rows to load (default 200)
|
||||
- `--include-empty` - include rows with no displayable/searchable content
|
||||
- `--json` - print crawlkit browser rows as JSON instead of opening the TUI
|
||||
|
||||
## Notes
|
||||
|
||||
- `tui` is read-only.
|
||||
- without `--guild`, `--guilds`, or `--dm`, it uses `default_guild_id` when configured; otherwise it can browse all stored guild rows
|
||||
- `--dm` only shows messages imported from the local Discord Desktop cache by [`wiretap`](wiretap.html)
|
||||
- `--json` is useful for launchers and agents that want the same row shape without an interactive terminal
|
||||
|
||||
## See also
|
||||
|
||||
- [`messages`](messages.html)
|
||||
- [`dms`](dms.html)
|
||||
- [`wiretap`](wiretap.html)
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
Forces a Git snapshot pull and import.
|
||||
|
||||
Routine imports are delta-planned from crawlkit shard fingerprints, with a Git-object fallback for older manifests. The usual publish only imports changed tail shards; unsafe table changes fall back to a full import.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
@ -19,7 +21,7 @@ discrawl update --with-embeddings
|
||||
|
||||
## When to use it
|
||||
|
||||
- you have `share.remote` configured and want a fresh import before running a command that does not auto-update (`sync` does not auto-import unless `--update=auto` is passed)
|
||||
- you have `share.remote` configured and want a fresh shard-delta import before running a command that does not auto-update (`sync` does not auto-import unless `--update=auto` is passed)
|
||||
- you set `--no-auto-update` when subscribing and want to refresh on demand
|
||||
- a CI job already imported the latest snapshot but read commands still consider it stale
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ discrawl subscribe --stale-after 15m https://github.com/example/discord-archive.
|
||||
discrawl subscribe --no-auto-update https://github.com/example/discord-archive.git
|
||||
```
|
||||
|
||||
`discrawl update` forces the same pull/import step manually.
|
||||
`discrawl update` forces the same pull/import step manually. Snapshot imports are delta-planned from crawlkit shard fingerprints. Older manifests without those fields fall back to Git blob identity, so the common publish shape only imports the changed message tail shard plus small cursor tables. Unsafe table-shape changes still fall back to a full import.
|
||||
|
||||
`discrawl sync` does **not** auto-import the share unless `--update=auto` or `--update=force` is provided, so routine live refreshes stay fast.
|
||||
|
||||
@ -44,7 +44,7 @@ discrawl subscribe --no-auto-update https://github.com/example/discord-archive.g
|
||||
Keep normal Discord credentials configured **and** set `share.remote`:
|
||||
|
||||
```bash
|
||||
discrawl sync --update=auto # import snapshot first, then live deltas
|
||||
discrawl sync --update=auto # import snapshot delta first, then live deltas
|
||||
discrawl messages --sync # blocking pre-query sync for matched scope
|
||||
discrawl sync --all-channels # broader live repair
|
||||
discrawl sync --full # historical backfill
|
||||
|
||||
@ -19,7 +19,7 @@ Sync modes control the Discord bot API side of a run. When `wiretap` is selected
|
||||
| Command | Use when | Behavior |
|
||||
| --- | --- | --- |
|
||||
| `discrawl sync` | routine refresh | skips member refreshes, checks live top-level channels plus active threads, only fetches new messages for channels with a stored latest cursor |
|
||||
| `discrawl sync --update=auto` | hybrid Git/live refresh | imports a stale Git snapshot first, then runs the routine live refresh |
|
||||
| `discrawl sync --update=auto` | hybrid Git/live refresh | imports a stale Git snapshot first, usually as a changed-shard delta, then runs the routine live refresh |
|
||||
| `discrawl sync --all-channels` | repair pass | broad incremental sweep across every stored channel/thread, including archived threads |
|
||||
| `discrawl sync --full` | historical backfill | crawls older history until channels are complete; can take a long time on large servers |
|
||||
|
||||
@ -43,6 +43,8 @@ Run one explicit `--full` pass when you want a complete historical guild archive
|
||||
- If in-flight channels stop completing for a while, `discrawl` emits `message sync waiting` heartbeat logs with the oldest active channel, per-channel page activity, and skip/defer counters.
|
||||
- Every run ends with a `message sync finished` summary.
|
||||
- Each channel crawl has a bounded runtime budget; pathological channels are deferred and retried on the next sync.
|
||||
- Retryable failures and unavailable-channel markers are tracked per channel; stale unavailable markers are cleared after a later successful crawl.
|
||||
- Marker cleanup is best-effort, so one missing local sync-state row cannot crash the run.
|
||||
- Full sync member refresh is best-effort and gives up after five minutes without a caller-supplied deadline, so message sync completion is not held hostage by a slow guild member crawl.
|
||||
- When the archive is already complete, `sync --full` reuses backlog markers and limits steady-state refresh to live top-level channels plus active threads instead of revisiting every stored archived thread.
|
||||
- If a guild already has a local member snapshot, routine syncs reuse it and skip another full member crawl until that snapshot ages out.
|
||||
|
||||
15
go.mod
15
go.mod
@ -1,6 +1,6 @@
|
||||
module github.com/openclaw/discrawl
|
||||
|
||||
go 1.26.2
|
||||
go 1.26.3
|
||||
|
||||
require (
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
@ -13,9 +13,8 @@ require (
|
||||
|
||||
require (
|
||||
github.com/charmbracelet/bubbles v1.0.0 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 // indirect
|
||||
modernc.org/sqlite v1.50.0 // indirect
|
||||
)
|
||||
@ -25,7 +24,7 @@ require (
|
||||
github.com/charmbracelet/bubbletea v1.3.10 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.0 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.7 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/danieljoos/wincred v1.2.3 // indirect
|
||||
@ -36,18 +35,18 @@ require (
|
||||
github.com/google/pprof v0.0.0-20260402051712-545e8a4df936 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.22 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/openclaw/crawlkit v0.5.0
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/vincentkoc/crawlkit v0.4.0
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
golang.org/x/tools v0.44.0 // indirect
|
||||
|
||||
26
go.sum
26
go.sum
@ -10,18 +10,16 @@ github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
|
||||
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/danieljoos/wincred v1.2.3 h1:v7dZC2x32Ut3nEfRH+vhoZGvN72+dQ/snVXo/vMFLdQ=
|
||||
github.com/danieljoos/wincred v1.2.3/go.mod h1:6qqX0WNrS4RzPZ1tnroDzq9kY3fu1KwE7MRLQK4X0bs=
|
||||
@ -49,14 +47,14 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
|
||||
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@ -65,6 +63,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/openclaw/crawlkit v0.5.0 h1:sVqIbQ5v6LiOf+NXcVj93UhfoaJqMbBlrd1lU6uhO9M=
|
||||
github.com/openclaw/crawlkit v0.5.0/go.mod h1:/AI8o/DeRqXPZJPHq/9mGUjNzLPskm/wTjikRPxEdHY=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
@ -80,8 +80,6 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/vincentkoc/crawlkit v0.4.0 h1:1jQZAYbBivy6d7ewNdMZ8THgmJVwb+pQT0kH5Z9COHI=
|
||||
github.com/vincentkoc/crawlkit v0.4.0/go.mod h1:/ioLA/tyZ/927kAOGg0M8Mrqk7pnTZLpCKWfpul9zoE=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/zalando/go-keyring v0.2.8 h1:6sD/Ucpl7jNq10rM2pgqTs0sZ9V3qMrqfIIy5YPccHs=
|
||||
|
||||
@ -13,10 +13,10 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/openclaw/crawlkit/embed"
|
||||
"github.com/openclaw/discrawl/internal/config"
|
||||
"github.com/openclaw/discrawl/internal/discord"
|
||||
"github.com/openclaw/discrawl/internal/discorddesktop"
|
||||
"github.com/openclaw/discrawl/internal/embed"
|
||||
"github.com/openclaw/discrawl/internal/share"
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
"github.com/openclaw/discrawl/internal/syncer"
|
||||
@ -374,7 +374,7 @@ func (r *runtime) runEmbed(args []string) error {
|
||||
providerFactory := r.newEmbed
|
||||
if providerFactory == nil {
|
||||
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
|
||||
return embed.NewProvider(cfg)
|
||||
return embed.NewProvider(crawlkitEmbeddingConfig(cfg))
|
||||
}
|
||||
}
|
||||
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
||||
@ -435,7 +435,7 @@ func (r *runtime) runDoctor(args []string) error {
|
||||
report["share_stale_after"] = cfg.Share.StaleAfter
|
||||
}
|
||||
if cfg.Search.Embeddings.Enabled {
|
||||
check := embed.CheckProvider(r.ctx, cfg.Search.Embeddings)
|
||||
check := embed.CheckProvider(r.ctx, crawlkitEmbeddingConfig(cfg.Search.Embeddings))
|
||||
report["embeddings"] = check.Status
|
||||
report["embeddings_provider"] = check.Provider
|
||||
report["embeddings_model"] = check.Model
|
||||
|
||||
@ -20,11 +20,11 @@ func (r *runtime) runAnalytics(args []string) error {
|
||||
subArgs := args[1:]
|
||||
switch subcommand {
|
||||
case "quiet":
|
||||
return r.withLocalStoreDefaultLocked(true, true, func() error {
|
||||
return r.withLocalStoreRead(true, func() error {
|
||||
return r.runAnalyticsQuiet(subArgs)
|
||||
})
|
||||
case "trends":
|
||||
return r.withLocalStoreDefaultLocked(true, true, func() error {
|
||||
return r.withLocalStoreRead(true, func() error {
|
||||
return r.runAnalyticsTrends(subArgs)
|
||||
})
|
||||
default:
|
||||
|
||||
@ -11,9 +11,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/openclaw/crawlkit/embed"
|
||||
"github.com/openclaw/discrawl/internal/config"
|
||||
"github.com/openclaw/discrawl/internal/discord"
|
||||
"github.com/openclaw/discrawl/internal/embed"
|
||||
"github.com/openclaw/discrawl/internal/share"
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
"github.com/openclaw/discrawl/internal/syncer"
|
||||
@ -118,6 +118,17 @@ type runtime struct {
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func crawlkitEmbeddingConfig(cfg config.EmbeddingsConfig) embed.Config {
|
||||
return embed.Config{
|
||||
Provider: cfg.Provider,
|
||||
Model: cfg.Model,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKeyEnv: cfg.APIKeyEnv,
|
||||
RequestTimeout: cfg.RequestTimeout,
|
||||
MaxInputChars: cfg.MaxInputChars,
|
||||
}
|
||||
}
|
||||
|
||||
type discordClient interface {
|
||||
syncer.Client
|
||||
Close() error
|
||||
@ -155,7 +166,7 @@ func (r *runtime) dispatch(rest []string) error {
|
||||
return r.withLocalStoreLocked(false, func() error { return r.runWiretap(rest[1:]) })
|
||||
case "search":
|
||||
autoShareUpdate := !hasBoolFlag(rest[1:], "--dm")
|
||||
return r.withLocalStoreDefaultLocked(autoShareUpdate, autoShareUpdate, func() error { return r.runSearch(rest[1:]) })
|
||||
return r.withLocalStoreRead(autoShareUpdate, func() error { return r.runSearch(rest[1:]) })
|
||||
case "tui":
|
||||
if hasHelpArg(rest[1:]) {
|
||||
return r.runTUI(rest[1:])
|
||||
@ -166,27 +177,30 @@ func (r *runtime) dispatch(rest []string) error {
|
||||
return r.withServicesAutoLocked(true, true, true, func() error { return r.runMessages(rest[1:]) })
|
||||
}
|
||||
autoShareUpdate := !hasBoolFlag(rest[1:], "--dm")
|
||||
return r.withLocalStoreDefaultLocked(autoShareUpdate, autoShareUpdate, func() error { return r.runMessages(rest[1:]) })
|
||||
return r.withLocalStoreRead(autoShareUpdate, func() error { return r.runMessages(rest[1:]) })
|
||||
case "digest":
|
||||
return r.withLocalStoreDefaultLocked(true, true, func() error { return r.runDigest(rest[1:]) })
|
||||
return r.withLocalStoreRead(true, func() error { return r.runDigest(rest[1:]) })
|
||||
case "analytics":
|
||||
return r.runAnalytics(rest[1:])
|
||||
case "dms":
|
||||
return r.withLocalStoreDefault(false, func() error { return r.runDirectMessages(rest[1:]) })
|
||||
return r.withLocalStoreRead(false, func() error { return r.runDirectMessages(rest[1:]) })
|
||||
case "mentions":
|
||||
return r.withLocalStoreLocked(true, func() error { return r.runMentions(rest[1:]) })
|
||||
return r.withLocalStoreRead(true, func() error { return r.runMentions(rest[1:]) })
|
||||
case "embed":
|
||||
return r.withLocalStoreLocked(true, func() error { return r.runEmbed(rest[1:]) })
|
||||
case "sql":
|
||||
return r.withLocalStoreLocked(true, func() error { return r.runSQL(rest[1:]) })
|
||||
if boolFlagEnabled(rest[1:], "--unsafe") {
|
||||
return r.withLocalStoreLocked(true, func() error { return r.runSQL(rest[1:]) })
|
||||
}
|
||||
return r.withLocalStoreRead(true, func() error { return r.runSQL(rest[1:]) })
|
||||
case "members":
|
||||
return r.withLocalStoreLocked(true, func() error { return r.runMembers(rest[1:]) })
|
||||
return r.withLocalStoreRead(true, func() error { return r.runMembers(rest[1:]) })
|
||||
case "channels":
|
||||
return r.withLocalStoreLocked(true, func() error { return r.runChannels(rest[1:]) })
|
||||
return r.withLocalStoreRead(true, func() error { return r.runChannels(rest[1:]) })
|
||||
case "status":
|
||||
return r.withLocalStoreReadOnly(func() error { return r.runStatus(rest[1:]) })
|
||||
case "report":
|
||||
return r.withLocalStoreLocked(true, func() error { return r.runReport(rest[1:]) })
|
||||
return r.withLocalStoreRead(true, func() error { return r.runReport(rest[1:]) })
|
||||
case "publish":
|
||||
return r.withServicesAutoLocked(false, false, true, func() error { return r.runPublish(rest[1:]) })
|
||||
case "subscribe":
|
||||
@ -212,12 +226,35 @@ func (r *runtime) withLocalStoreLocked(autoShareUpdate bool, fn func() error) er
|
||||
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), true, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) withLocalStoreDefault(autoShareUpdate bool, fn func() error) error {
|
||||
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), false, fn)
|
||||
func (r *runtime) withLocalStoreRead(autoShareUpdate bool, fn func() error) error {
|
||||
return r.withLocalStoreReadUpdate(boolShareUpdateMode(autoShareUpdate), fn)
|
||||
}
|
||||
|
||||
func (r *runtime) withLocalStoreDefaultLocked(autoShareUpdate, lockDB bool, fn func() error) error {
|
||||
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), lockDB, fn)
|
||||
func (r *runtime) withLocalStoreReadUpdate(updateMode shareUpdateMode, fn func() error) error {
|
||||
cfg, err := config.Load(r.configPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return configErr(err)
|
||||
}
|
||||
cfg = config.Default()
|
||||
if err := cfg.Normalize(); err != nil {
|
||||
return configErr(err)
|
||||
}
|
||||
}
|
||||
if err := config.EnsureRuntimeDirs(cfg); err != nil {
|
||||
return configErr(err)
|
||||
}
|
||||
dbPath, err := config.ExpandPath(cfg.DBPath)
|
||||
if err != nil {
|
||||
return configErr(err)
|
||||
}
|
||||
r.cfg = cfg
|
||||
if r.shouldAutoUpdateShare(updateMode) {
|
||||
if err := r.autoUpdateShareIfLockAvailable(dbPath, updateMode); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return r.openLocalStoreReadOnly(dbPath, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) withLocalStoreUpdateLocked(updateMode shareUpdateMode, lockDB bool, fn func() error) error {
|
||||
@ -247,6 +284,38 @@ func (r *runtime) withLocalStoreUpdateLocked(updateMode shareUpdateMode, lockDB
|
||||
return r.openLocalStore(dbPath, updateMode, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) shouldAutoUpdateShare(mode shareUpdateMode) bool {
|
||||
return os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" &&
|
||||
r.cfg.ShareEnabled() &&
|
||||
(mode == shareUpdateForce || mode == shareUpdateAuto || (mode == shareUpdateConfigured && r.cfg.Share.AutoUpdate))
|
||||
}
|
||||
|
||||
func (r *runtime) autoUpdateShareIfLockAvailable(dbPath string, updateMode shareUpdateMode) error {
|
||||
locked, err := r.tryWithSyncLock(func() error {
|
||||
storeFactory := r.openStore
|
||||
if storeFactory == nil {
|
||||
storeFactory = store.Open
|
||||
}
|
||||
var openErr error
|
||||
r.store, openErr = storeFactory(r.ctx, dbPath)
|
||||
if openErr != nil {
|
||||
return dbErr(openErr)
|
||||
}
|
||||
defer func() {
|
||||
_ = r.store.Close()
|
||||
r.store = nil
|
||||
}()
|
||||
return r.autoUpdateShare(updateMode)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !locked {
|
||||
r.logger.Info("share update skipped; sync lock is held")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *runtime) openLocalStore(dbPath string, updateMode shareUpdateMode, fn func() error) error {
|
||||
storeFactory := r.openStore
|
||||
if storeFactory == nil {
|
||||
@ -295,6 +364,50 @@ func (r *runtime) withLocalStoreReadOnly(fn func() error) error {
|
||||
return fn()
|
||||
}
|
||||
|
||||
func (r *runtime) openLocalStoreReadOnly(dbPath string, fn func() error) error {
|
||||
var openErr error
|
||||
r.store, openErr = store.OpenReadOnly(r.ctx, dbPath)
|
||||
if openErr != nil {
|
||||
if errors.Is(openErr, os.ErrNotExist) {
|
||||
storeFactory := r.openStore
|
||||
if storeFactory == nil {
|
||||
storeFactory = store.Open
|
||||
}
|
||||
r.store, openErr = storeFactory(r.ctx, dbPath)
|
||||
if openErr == nil {
|
||||
defer func() { _ = r.store.Close() }()
|
||||
return fn()
|
||||
}
|
||||
}
|
||||
if errors.Is(openErr, store.ErrSchemaVersionMismatch) {
|
||||
if err := r.withSyncLock(func() error {
|
||||
storeFactory := r.openStore
|
||||
if storeFactory == nil {
|
||||
storeFactory = store.Open
|
||||
}
|
||||
var migrateErr error
|
||||
r.store, migrateErr = storeFactory(r.ctx, dbPath)
|
||||
if migrateErr != nil {
|
||||
return dbErr(migrateErr)
|
||||
}
|
||||
closeErr := r.store.Close()
|
||||
r.store = nil
|
||||
return closeErr
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
r.store, openErr = store.OpenReadOnly(r.ctx, dbPath)
|
||||
if openErr == nil {
|
||||
defer func() { _ = r.store.Close() }()
|
||||
return fn()
|
||||
}
|
||||
}
|
||||
return dbErr(openErr)
|
||||
}
|
||||
defer func() { _ = r.store.Close() }()
|
||||
return fn()
|
||||
}
|
||||
|
||||
func (r *runtime) withServicesAuto(withDiscord, autoShareUpdate bool, fn func() error) error {
|
||||
return r.withServicesAutoLocked(withDiscord, autoShareUpdate, false, fn)
|
||||
}
|
||||
|
||||
@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -20,6 +22,8 @@ import (
|
||||
|
||||
"github.com/openclaw/discrawl/internal/config"
|
||||
discordclient "github.com/openclaw/discrawl/internal/discord"
|
||||
"github.com/openclaw/discrawl/internal/discorddesktop"
|
||||
"github.com/openclaw/discrawl/internal/report"
|
||||
"github.com/openclaw/discrawl/internal/share"
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
"github.com/openclaw/discrawl/internal/syncer"
|
||||
@ -38,6 +42,192 @@ func TestHelpAndVersion(t *testing.T) {
|
||||
|
||||
err := Run(context.Background(), []string{"bogus"}, &out, &bytes.Buffer{})
|
||||
require.Equal(t, 2, ExitCode(err))
|
||||
require.Equal(t, 1, ExitCode(context.Canceled))
|
||||
require.Equal(t, 7, ExitCode(&cliError{code: 7, err: errors.New("custom")}))
|
||||
}
|
||||
|
||||
func TestCommandValidationEdges(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
dbPath := filepath.Join(dir, "discrawl.db")
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = dbPath
|
||||
cfg.Discord.TokenSource = "none"
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
s, err := store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
cases := [][]string{
|
||||
{"--config", cfgPath, "--bogus"},
|
||||
{"--config", cfgPath, "search"},
|
||||
{"--config", cfgPath, "search", "--mode", "bogus", "term"},
|
||||
{"--config", cfgPath, "messages"},
|
||||
{"--config", cfgPath, "messages", "--hours", "-1", "--channel", "general"},
|
||||
{"--config", cfgPath, "messages", "--hours", "1", "--days", "1", "--channel", "general"},
|
||||
{"--config", cfgPath, "messages", "--all", "--last", "1", "--channel", "general"},
|
||||
{"--config", cfgPath, "messages", "--dm", "--sync", "--channel", "alice"},
|
||||
{"--config", cfgPath, "dms", "--hours", "-1"},
|
||||
{"--config", cfgPath, "dms", "--limit", "1", "--last", "1", "--with", "alice"},
|
||||
{"--config", cfgPath, "mentions"},
|
||||
{"--config", cfgPath, "mentions", "--days", "-1", "--target", "u1"},
|
||||
{"--config", cfgPath, "mentions", "--type", "channel", "--target", "u1"},
|
||||
{"--config", cfgPath, "digest", "--since", "-1d"},
|
||||
{"--config", cfgPath, "analytics", "wat"},
|
||||
{"--config", cfgPath, "analytics", "quiet", "extra"},
|
||||
{"--config", cfgPath, "analytics", "trends", "--weeks", "-1"},
|
||||
{"--config", cfgPath, "channels"},
|
||||
{"--config", cfgPath, "channels", "wat"},
|
||||
{"--config", cfgPath, "channels", "show"},
|
||||
{"--config", cfgPath, "status", "extra"},
|
||||
{"--config", cfgPath, "report", "extra"},
|
||||
{"--config", cfgPath, "wiretap", "extra"},
|
||||
{"--config", cfgPath, "wiretap", "--max-file-bytes", "0"},
|
||||
{"--config", cfgPath, "sync", "--source", "bogus"},
|
||||
{"--config", cfgPath, "sync", "--since", "not-time"},
|
||||
{"--config", cfgPath, "sync", "--no-update", "--update", "force"},
|
||||
{"--config", cfgPath, "publish", "--remote", ""},
|
||||
{"--config", cfgPath, "subscribe"},
|
||||
{"--config", cfgPath, "update", "extra"},
|
||||
{"--config", cfgPath, "sql", "--confirm", "select 1"},
|
||||
{"--config", cfgPath, "sql", "--unsafe", "select 1"},
|
||||
{"--config", cfgPath, "members"},
|
||||
{"--config", cfgPath, "members", "wat"},
|
||||
}
|
||||
for _, args := range cases {
|
||||
var stdout, stderr bytes.Buffer
|
||||
err := Run(ctx, args, &stdout, &stderr)
|
||||
require.Error(t, err, args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputBranches(t *testing.T) {
|
||||
now := time.Date(2026, 5, 8, 12, 0, 0, 0, time.UTC)
|
||||
values := []any{
|
||||
syncRunStats{
|
||||
Source: "both",
|
||||
Discord: &syncer.SyncStats{Guilds: 1, Channels: 2, Threads: 3, Members: 4, Messages: 5},
|
||||
Wiretap: &discorddesktop.Stats{
|
||||
Path: "/tmp/discord",
|
||||
FilesVisited: 1,
|
||||
FilesScanned: 2,
|
||||
FilesSkipped: 3,
|
||||
FilesUnchanged: 4,
|
||||
CacheFilesFastSkipped: 5,
|
||||
JSONObjects: 6,
|
||||
Guilds: 7,
|
||||
Channels: 8,
|
||||
Messages: 9,
|
||||
DMMessages: 10,
|
||||
DMChannels: 11,
|
||||
GuildMessages: 12,
|
||||
SkippedMessages: 13,
|
||||
SkippedChannels: 14,
|
||||
Checkpoints: 15,
|
||||
FullCache: true,
|
||||
DryRun: true,
|
||||
},
|
||||
},
|
||||
syncer.SyncStats{Guilds: 1, Channels: 2, Threads: 3, Members: 4, Messages: 5},
|
||||
discorddesktop.Stats{Path: "/tmp/discord", FilesVisited: 1, FullCache: true, DryRun: true},
|
||||
store.EmbeddingDrainStats{
|
||||
Processed: 3,
|
||||
Succeeded: 2,
|
||||
Failed: 1,
|
||||
Requeued: 4,
|
||||
RateLimited: true,
|
||||
RemainingBacklog: 5,
|
||||
Provider: "openai",
|
||||
Model: "model",
|
||||
InputVersion: "v1",
|
||||
},
|
||||
[]store.DirectMessageConversationRow{{
|
||||
ChannelID: "c1",
|
||||
Name: "Alice",
|
||||
MessageCount: 2,
|
||||
AuthorCount: 1,
|
||||
FirstMessageAt: now.Add(-time.Hour),
|
||||
LastMessageAt: now,
|
||||
}},
|
||||
store.MemberProfile{
|
||||
Member: store.MemberRow{
|
||||
GuildID: "g1",
|
||||
UserID: "u1",
|
||||
Username: "peter",
|
||||
DisplayName: "Peter",
|
||||
JoinedAt: now,
|
||||
XHandle: "steipete",
|
||||
GitHubLogin: "steipete",
|
||||
Website: "https://steipete.me",
|
||||
Pronouns: "he/him",
|
||||
Location: "Vienna",
|
||||
Bio: "Maintainer",
|
||||
URLs: []string{"https://example.com"},
|
||||
},
|
||||
MessageCount: 1,
|
||||
FirstMessageAt: now.Add(-time.Hour),
|
||||
LastMessageAt: now,
|
||||
RecentMessages: []store.MessageRow{{ChannelName: "general", CreatedAt: now, Content: "hello"}},
|
||||
},
|
||||
report.Digest{
|
||||
Since: now.Add(-24 * time.Hour),
|
||||
Until: now,
|
||||
WindowLabel: "1d",
|
||||
Channels: []report.ChannelDigest{{
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
Kind: "text",
|
||||
GuildID: "g1",
|
||||
Messages: 3,
|
||||
Replies: 1,
|
||||
ActiveAuthors: 2,
|
||||
TopPosters: []report.RankedCount{{Name: "Peter", Count: 2}},
|
||||
TopMentions: []report.RankedCount{{Count: 1}},
|
||||
}},
|
||||
Totals: report.DigestTotals{Messages: 3, Replies: 1, Channels: 1, ActiveAuthors: 2},
|
||||
},
|
||||
report.Quiet{
|
||||
Since: now.Add(-24 * time.Hour),
|
||||
Until: now,
|
||||
Channels: []report.QuietChannel{{
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
Kind: "text",
|
||||
LastMessage: "",
|
||||
DaysSilent: -1,
|
||||
}},
|
||||
Totals: report.QuietTotals{Channels: 1},
|
||||
},
|
||||
report.Trends{
|
||||
Since: now.AddDate(0, 0, -14),
|
||||
Until: now,
|
||||
Weeks: 2,
|
||||
Rows: []report.TrendsRow{{
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
Kind: "text",
|
||||
GuildID: "g1",
|
||||
Weekly: []report.WeeklyCount{
|
||||
{WeekStart: now.AddDate(0, 0, -14), Messages: 1},
|
||||
{WeekStart: now.AddDate(0, 0, -7), Messages: 2},
|
||||
},
|
||||
}},
|
||||
},
|
||||
map[string]any{"b": 2, "a": 1},
|
||||
}
|
||||
for _, value := range values {
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, printHuman(&out, value))
|
||||
require.NotEmpty(t, out.String())
|
||||
}
|
||||
|
||||
var plain bytes.Buffer
|
||||
require.NoError(t, printPlain(&plain, report.Quiet{Channels: []report.QuietChannel{{ChannelID: "c1", ChannelName: "general", Kind: "text", GuildID: "g1", LastMessage: "now", DaysSilent: 0}}}))
|
||||
require.NoError(t, printPlain(&plain, report.Trends{Rows: []report.TrendsRow{{GuildID: "g1", ChannelID: "c1", ChannelName: "general", Kind: "text", Weekly: []report.WeeklyCount{{WeekStart: now, Messages: 2}}}}}))
|
||||
require.Error(t, printPlain(io.Discard, struct{}{}))
|
||||
require.Error(t, printHuman(io.Discard, struct{}{}))
|
||||
require.Equal(t, "this is a profile field with a very l...", trimForTable("this is a profile field with a very long text value"))
|
||||
}
|
||||
|
||||
func TestStatusSearchSQLAndListings(t *testing.T) {
|
||||
@ -135,6 +325,7 @@ func TestStatusSearchSQLAndListings(t *testing.T) {
|
||||
tests := [][]string{
|
||||
{"--config", cfgPath, "status"},
|
||||
{"--config", cfgPath, "search", "panic"},
|
||||
{"--config", cfgPath, "search", "panic", "--limit", "1"},
|
||||
{"--config", cfgPath, "search", "stack"},
|
||||
{"--config", cfgPath, "search", "--include-empty", "Peter"},
|
||||
{"--config", cfgPath, "messages", "--channel", "general", "--days", "7", "--all"},
|
||||
@ -863,6 +1054,63 @@ func TestSyncLockSerializesConcurrentRuns(t *testing.T) {
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestReadCommandsDoNotWaitForSyncLock(t *testing.T) {
|
||||
if goruntime.GOOS == "windows" {
|
||||
t.Skip("sync lock timing is flaky on Windows")
|
||||
}
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = filepath.Join(dir, "discrawl.db")
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
s := seedCLIStore(t, cfg.DBPath)
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
firstRelease, err := acquireSyncLock(ctx, filepath.Join(dir, ".discrawl-sync.lock"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = firstRelease() }()
|
||||
|
||||
for _, args := range [][]string{
|
||||
{"--config", cfgPath, "search", "automatic"},
|
||||
{"--config", cfgPath, "messages", "--channel", "general", "--last", "1"},
|
||||
{"--config", cfgPath, "sql", "select count(*) as total from messages"},
|
||||
} {
|
||||
runCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
var out bytes.Buffer
|
||||
err := Run(runCtx, args, &out, &bytes.Buffer{})
|
||||
cancel()
|
||||
require.NoError(t, err, args)
|
||||
require.NotEmpty(t, out.String(), args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCommandsMigrateOlderLocalStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = filepath.Join(dir, "discrawl.db")
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
s := seedCLIStore(t, cfg.DBPath)
|
||||
_, err := s.DB().ExecContext(ctx, `pragma user_version = 1`)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "automatic"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "automatic updates work")
|
||||
|
||||
reader, err := store.OpenReadOnly(ctx, cfg.DBPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = reader.Close() }()
|
||||
var version int
|
||||
require.NoError(t, reader.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version))
|
||||
require.Equal(t, 2, version)
|
||||
}
|
||||
|
||||
func seedCLIStore(t *testing.T, path string) *store.Store {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
@ -1767,7 +2015,49 @@ func TestRuntimeHelpersAndSubcommands(t *testing.T) {
|
||||
s, err := store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "dm1", GuildID: store.DirectMessageGuildID, Kind: "dm", Name: "Alice", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertMember(ctx, store.MemberRecord{GuildID: "g1", UserID: "u1", Username: "peter", RoleIDsJSON: `[]`, RawJSON: `{}`}))
|
||||
base := time.Date(2026, 3, 8, 10, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.UpsertMessages(ctx, []store.MessageMutation{
|
||||
{
|
||||
Record: store.MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "peter",
|
||||
CreatedAt: base.Format(time.RFC3339Nano),
|
||||
Content: "hello <@u1> in <#c1>",
|
||||
NormalizedContent: "hello <@u1> in <#c1>",
|
||||
RawJSON: `{"author":{"username":"peter"}}`,
|
||||
},
|
||||
Mentions: []store.MentionEventRecord{{
|
||||
MessageID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
AuthorID: "u1",
|
||||
TargetType: "user",
|
||||
TargetID: "u1",
|
||||
TargetName: "peter",
|
||||
EventAt: base.Format(time.RFC3339Nano),
|
||||
}},
|
||||
},
|
||||
{
|
||||
Record: store.MessageRecord{
|
||||
ID: "dm-msg",
|
||||
GuildID: store.DirectMessageGuildID,
|
||||
ChannelID: "dm1",
|
||||
ChannelName: "Alice",
|
||||
AuthorID: "u2",
|
||||
AuthorName: "Alice",
|
||||
CreatedAt: base.Add(time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "private hello",
|
||||
NormalizedContent: "private hello",
|
||||
RawJSON: `{"source":"discord_desktop"}`,
|
||||
},
|
||||
},
|
||||
}))
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
rt := &runtime{
|
||||
@ -1787,11 +2077,23 @@ func TestRuntimeHelpersAndSubcommands(t *testing.T) {
|
||||
require.NoError(t, rt.runMessages([]string{"--channel", "#general", "--hours", "6", "--last", "1"}))
|
||||
require.NoError(t, rt.runMessages([]string{"--channel", "#general", "--days", "7", "--all"}))
|
||||
require.NoError(t, rt.runMessages([]string{"--channel", "#general", "--days", "7", "--all", "--include-empty"}))
|
||||
require.NoError(t, rt.runMessages([]string{"--channel", "#general", "--since", "2026-03-08T00:00:00Z", "--before", "2026-03-09T00:00:00Z", "--limit", "1"}))
|
||||
require.NoError(t, rt.runMessages([]string{"--dm", "--channel", "Alice", "--last", "1"}))
|
||||
require.NoError(t, rt.runDirectMessages([]string{"--list"}))
|
||||
require.NoError(t, rt.runDirectMessages([]string{"--with", "Alice", "--search", "private", "--limit", "1"}))
|
||||
require.NoError(t, rt.runDirectMessages([]string{"--with", "Alice", "--since", "2026-03-08T00:00:00Z", "--before", "2026-03-09T00:00:00Z", "--all"}))
|
||||
require.NoError(t, rt.runMentions([]string{"--channel", "#general", "--target", "u2"}))
|
||||
require.NoError(t, rt.runMentions([]string{"--channel", "#general", "--days", "7", "--type", "user"}))
|
||||
require.NoError(t, rt.runDigest([]string{"--since", "12h", "--channel", "general", "--top-n", "2"}))
|
||||
require.NoError(t, rt.runReport([]string{"--readme", filepath.Join(dir, "README.md")}))
|
||||
require.NoError(t, rt.runSearch([]string{"--include-empty", "Peter"}))
|
||||
require.NoError(t, rt.runChannels([]string{"show", "c1"}))
|
||||
require.NoError(t, rt.runChannels([]string{"list"}))
|
||||
require.NoError(t, rt.runStatus(nil))
|
||||
require.NoError(t, rt.runAnalytics([]string{}))
|
||||
require.NoError(t, rt.runTUI([]string{"--json", "--limit", "1", "--include-empty"}))
|
||||
require.NoError(t, rt.runAnalytics([]string{"quiet", "--since", "1d"}))
|
||||
require.NoError(t, rt.runAnalytics([]string{"trends", "--weeks", "1", "--channel", "general"}))
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
@ -8,9 +8,9 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/openclaw/crawlkit/control"
|
||||
"github.com/openclaw/discrawl/internal/config"
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
"github.com/vincentkoc/crawlkit/control"
|
||||
)
|
||||
|
||||
func (r *runtime) runMetadata(args []string) error {
|
||||
|
||||
@ -9,8 +9,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/openclaw/crawlkit/embed"
|
||||
"github.com/openclaw/discrawl/internal/config"
|
||||
"github.com/openclaw/discrawl/internal/embed"
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
)
|
||||
|
||||
@ -25,7 +25,7 @@ func (r *runtime) runSearch(args []string) error {
|
||||
dm := fs.Bool("dm", false, "")
|
||||
guildsFlag := fs.String("guilds", "", "")
|
||||
guildFlag := fs.String("guild", "", "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
if err := fs.Parse(permuteSearchFlags(args)); err != nil {
|
||||
return usageErr(err)
|
||||
}
|
||||
if fs.NArg() != 1 {
|
||||
@ -67,6 +67,51 @@ func (r *runtime) runSearch(args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
func permuteSearchFlags(args []string) []string {
|
||||
valueFlags := map[string]struct{}{
|
||||
"--mode": {},
|
||||
"--channel": {},
|
||||
"--author": {},
|
||||
"--limit": {},
|
||||
"--guilds": {},
|
||||
"--guild": {},
|
||||
}
|
||||
boolFlags := map[string]struct{}{
|
||||
"--include-empty": {},
|
||||
"--dm": {},
|
||||
}
|
||||
flags := make([]string, 0, len(args))
|
||||
positionals := make([]string, 0, len(args))
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
if arg == "--" {
|
||||
positionals = append(positionals, args[i+1:]...)
|
||||
break
|
||||
}
|
||||
if name, _, ok := strings.Cut(arg, "="); ok {
|
||||
if _, known := valueFlags[name]; known {
|
||||
flags = append(flags, arg)
|
||||
continue
|
||||
}
|
||||
if _, known := boolFlags[name]; known {
|
||||
flags = append(flags, arg)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if _, known := boolFlags[arg]; known {
|
||||
flags = append(flags, arg)
|
||||
continue
|
||||
}
|
||||
if _, known := valueFlags[arg]; known && i+1 < len(args) {
|
||||
flags = append(flags, arg, args[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
positionals = append(positionals, arg)
|
||||
}
|
||||
return append(flags, positionals...)
|
||||
}
|
||||
|
||||
func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.SearchResult, error) {
|
||||
semanticOpts, err := r.semanticSearchOptions(opts)
|
||||
if err != nil {
|
||||
@ -112,7 +157,7 @@ func (r *runtime) semanticSearchOptions(opts store.SearchOptions) (store.Semanti
|
||||
providerFactory := r.newEmbed
|
||||
if providerFactory == nil {
|
||||
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
|
||||
return embed.NewProvider(cfg)
|
||||
return embed.NewProvider(crawlkitEmbeddingConfig(cfg))
|
||||
}
|
||||
}
|
||||
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
||||
|
||||
@ -97,6 +97,21 @@ func hasBoolFlag(args []string, name string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func boolFlagEnabled(args []string, name string) bool {
|
||||
for _, arg := range args {
|
||||
if arg == name {
|
||||
return true
|
||||
}
|
||||
if raw, ok := strings.CutPrefix(arg, name+"="); ok {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "1", "t", "true", "y", "yes", "on":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasHelpArg(args []string) bool {
|
||||
for _, arg := range args {
|
||||
if arg == "help" || arg == "--help" || arg == "-h" {
|
||||
|
||||
@ -34,6 +34,29 @@ func (r *runtime) withSyncLock(fn func() error) error {
|
||||
return fn()
|
||||
}
|
||||
|
||||
func (r *runtime) tryWithSyncLock(fn func() error) (bool, error) {
|
||||
if r.dbLockHeld {
|
||||
return true, fn()
|
||||
}
|
||||
lockPath, err := r.syncLockPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
release, locked, err := tryAcquireSyncLock(lockPath)
|
||||
if err != nil || !locked {
|
||||
return locked, err
|
||||
}
|
||||
r.dbLockHeld = true
|
||||
r.lockStarted = r.nowUTC()
|
||||
r.setSyncLockPhase("locked")
|
||||
defer func() {
|
||||
r.dbLockHeld = false
|
||||
r.lockStarted = time.Time{}
|
||||
_ = release()
|
||||
}()
|
||||
return true, fn()
|
||||
}
|
||||
|
||||
func (r *runtime) setSyncLockPhase(phase string) {
|
||||
if !r.dbLockHeld {
|
||||
return
|
||||
|
||||
@ -7,3 +7,7 @@ import "context"
|
||||
func acquireSyncLock(context.Context, string) (func() error, error) {
|
||||
return func() error { return nil }, nil
|
||||
}
|
||||
|
||||
func tryAcquireSyncLock(string) (func() error, bool, error) {
|
||||
return func() error { return nil }, true, nil
|
||||
}
|
||||
|
||||
@ -51,3 +51,29 @@ func acquireSyncLock(ctx context.Context, path string) (func() error, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tryAcquireSyncLock(path string) (func() error, bool, error) {
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o600)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("open sync lock: %w", err)
|
||||
}
|
||||
err = unix.Flock(int(file.Fd()), unix.LOCK_EX|unix.LOCK_NB)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
if errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, unix.EAGAIN) {
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, fmt.Errorf("acquire sync lock: %w", err)
|
||||
}
|
||||
_, _ = file.Seek(0, 0)
|
||||
_ = file.Truncate(0)
|
||||
_, _ = fmt.Fprintf(file, "pid=%d\n", os.Getpid())
|
||||
return func() error {
|
||||
unlockErr := unix.Flock(int(file.Fd()), unix.LOCK_UN)
|
||||
closeErr := file.Close()
|
||||
if unlockErr != nil {
|
||||
return unlockErr
|
||||
}
|
||||
return closeErr
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
@ -49,3 +49,28 @@ func acquireSyncLock(ctx context.Context, path string) (func() error, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tryAcquireSyncLock(path string) (func() error, bool, error) {
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o600)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("open sync lock: %w", err)
|
||||
}
|
||||
handle := windows.Handle(file.Fd())
|
||||
overlapped := &windows.Overlapped{}
|
||||
err = windows.LockFileEx(handle, windows.LOCKFILE_EXCLUSIVE_LOCK|windows.LOCKFILE_FAIL_IMMEDIATELY, 0, 1, 0, overlapped)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return nil, false, nil
|
||||
}
|
||||
_, _ = file.Seek(0, 0)
|
||||
_ = file.Truncate(0)
|
||||
_, _ = fmt.Fprintf(file, "pid=%d\n", os.Getpid())
|
||||
return func() error {
|
||||
unlockErr := windows.UnlockFileEx(handle, 0, 1, 0, overlapped)
|
||||
closeErr := file.Close()
|
||||
if unlockErr != nil {
|
||||
return unlockErr
|
||||
}
|
||||
return closeErr
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/vincentkoc/crawlkit/tui"
|
||||
"github.com/openclaw/crawlkit/tui"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
crawlconfig "github.com/vincentkoc/crawlkit/config"
|
||||
crawlconfig "github.com/openclaw/crawlkit/config"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
package discorddesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@ -90,3 +93,106 @@ func TestCacheFileHasRouteHint(t *testing.T) {
|
||||
_, err = cacheFileHasRouteHint(root, "missing")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestImportAndStateEdgeBranches(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := Import(ctx, nil, Options{})
|
||||
require.ErrorContains(t, err, "store is required")
|
||||
|
||||
configHome := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configHome)
|
||||
if runtime.GOOS == "linux" {
|
||||
require.Equal(t, filepath.Join(configHome, "discord"), DefaultPath())
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
s, err := store.Open(ctx, filepath.Join(dir, "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
stats, err := Import(ctx, s, Options{
|
||||
Path: dir,
|
||||
Now: func() time.Time { return time.Date(2026, 5, 8, 12, 0, 0, 0, time.UTC) },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dir, stats.Path)
|
||||
require.Equal(t, 1, stats.Checkpoints)
|
||||
|
||||
stats, err = Import(ctx, nil, Options{Path: filepath.Join(dir, "missing"), DryRun: true})
|
||||
require.NoError(t, err)
|
||||
require.True(t, stats.DryRun)
|
||||
|
||||
stats, err = Import(ctx, nil, Options{Path: dir, DryRun: true, FullCache: true})
|
||||
require.NoError(t, err)
|
||||
require.True(t, stats.FullCache)
|
||||
|
||||
require.NoError(t, s.SetSyncState(ctx, fileIndexScope(Options{}), "{not-json"))
|
||||
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
state, err := loadScanState(ctx, s, Options{})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, state.previous)
|
||||
require.Equal(t, "general", state.channels["c1"].Name)
|
||||
}
|
||||
|
||||
func TestSnapshotFinalizeAndCommitBranches(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
snap := newSnapshot()
|
||||
snap.messages["m-missing"] = store.MessageMutation{
|
||||
Record: store.MessageRecord{ID: "m-missing", ChannelID: "c-missing", RawJSON: `{}`},
|
||||
}
|
||||
snap.messages["m-known"] = store.MessageMutation{
|
||||
Record: store.MessageRecord{ID: "m-known", GuildID: "g1", ChannelID: "c1", ChannelName: "general", RawJSON: `{}`},
|
||||
}
|
||||
stats := &Stats{}
|
||||
totals := newScanTotals()
|
||||
unresolved := finalizeSnapshot(snap, map[string]store.ChannelRecord{
|
||||
"c1": {ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`},
|
||||
}, totals, stats, true)
|
||||
require.Equal(t, unresolvedMessages{"m-missing": "c-missing"}, unresolved)
|
||||
require.Equal(t, 1, stats.Messages)
|
||||
require.Equal(t, 1, stats.SkippedMessages)
|
||||
require.Equal(t, "general", snap.channels["c1"].Name)
|
||||
require.Equal(t, "g1", snap.guilds["g1"].ID)
|
||||
|
||||
more := unresolvedMessages{"m2": "c2"}
|
||||
mergeUnresolved(unresolved, more)
|
||||
recordUnresolved(unresolved, totals, stats)
|
||||
require.Equal(t, 2, stats.SkippedMessages)
|
||||
|
||||
state := scanState{current: map[string]fileFingerprint{}}
|
||||
candidates := []fileCandidate{{relKey: "Cache_Data/entry", fingerprint: fileFingerprint{Size: 10, ModUnixNS: 20}}}
|
||||
require.NoError(t, commitSnapshot(ctx, s, Options{DryRun: true}, state, candidates, newSnapshot(), true, stats))
|
||||
require.NoError(t, commitSnapshot(ctx, s, Options{}, state, candidates, newSnapshot(), false, stats))
|
||||
require.NoError(t, commitSnapshot(ctx, s, Options{}, state, candidates, newSnapshot(), true, stats))
|
||||
require.True(t, isImportedFingerprint(state.current["Cache_Data/entry"]))
|
||||
|
||||
require.NoError(t, checkpointScannedCandidates(ctx, s, Options{DryRun: true}, state, candidates, stats))
|
||||
require.NoError(t, checkpointScannedCandidates(ctx, s, Options{}, state, candidates, stats))
|
||||
}
|
||||
|
||||
func TestRouteHintCollectionBranches(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "route"), []byte("https://discord.com/channels/123456789012/111111111111111121"), 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "plain"), []byte("plain"), 0o600))
|
||||
|
||||
root, err := os.OpenRoot(dir)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = root.Close() }()
|
||||
|
||||
snap := newSnapshot()
|
||||
err = collectCacheRouteHints(context.Background(), root, []fileCandidate{
|
||||
{relPath: "missing"},
|
||||
{relPath: "plain"},
|
||||
{relPath: "route"},
|
||||
}, snap)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "123456789012", snap.routes["111111111111111121"])
|
||||
|
||||
canceled, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
require.ErrorIs(t, collectCacheRouteHints(canceled, root, []fileCandidate{{relPath: "route"}}, newSnapshot()), context.Canceled)
|
||||
}
|
||||
|
||||
@ -3,8 +3,11 @@ package discorddesktop
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
)
|
||||
|
||||
func TestPrimitiveValueHelpers(t *testing.T) {
|
||||
@ -78,3 +81,85 @@ func TestDiscordValueFormatHelpers(t *testing.T) {
|
||||
require.Equal(t, "desktop", kindForChannelType(16, false))
|
||||
require.Equal(t, "text", kindForChannelType(0, false))
|
||||
}
|
||||
|
||||
func TestDiscordMessagePayloadHelpers(t *testing.T) {
|
||||
raw := map[string]any{
|
||||
"id": "333333333333333333",
|
||||
"channel_id": "111111111111111111",
|
||||
"guild_id": "999999999999999999",
|
||||
"type": float64(0),
|
||||
"timestamp": "2026-05-08T12:00:00Z",
|
||||
"edited_timestamp": "2026-05-08T12:05:00Z",
|
||||
"content": "hello\u200b\nworld",
|
||||
"message_reference": map[string]any{"message_id": "222222222222222222"},
|
||||
"author": map[string]any{
|
||||
"id": "444444444444444444",
|
||||
"username": "peter",
|
||||
"global_name": "Peter",
|
||||
"display_name": "Peter S",
|
||||
"discriminator": "0",
|
||||
"bot": true,
|
||||
},
|
||||
"attachments": []any{
|
||||
map[string]any{"filename": "trace.txt", "content_type": "text/plain", "size": float64(12), "url": "https://cdn.example/trace.txt"},
|
||||
map[string]any{"id": "att2"},
|
||||
"ignored",
|
||||
},
|
||||
"mentions": []any{
|
||||
map[string]any{"id": "555555555555555555", "username": "alice", "global_name": "Alice"},
|
||||
map[string]any{"username": "missing"},
|
||||
},
|
||||
"embeds": []any{
|
||||
map[string]any{"title": "Deploy", "description": "Ready"},
|
||||
map[string]any{"title": " "},
|
||||
},
|
||||
}
|
||||
at := parseDiscordTime("2026-05-08T12:00:00Z")
|
||||
attachments := parseAttachments(raw, "333333333333333333", "999999999999999999", "111111111111111111", "444444444444444444")
|
||||
require.Len(t, attachments, 2)
|
||||
require.Equal(t, "333333333333333333:0", attachments[0].AttachmentID)
|
||||
require.Equal(t, "trace.txt", attachments[0].Filename)
|
||||
require.Equal(t, "att2", attachments[1].Filename)
|
||||
require.Equal(t, []string{"trace.txt", "att2"}, attachmentText(attachments))
|
||||
|
||||
mentions := parseMentions(raw, "333333333333333333", "999999999999999999", "111111111111111111", "444444444444444444", at)
|
||||
require.Equal(t, []store.MentionEventRecord{{
|
||||
MessageID: "333333333333333333",
|
||||
GuildID: "999999999999999999",
|
||||
ChannelID: "111111111111111111",
|
||||
AuthorID: "444444444444444444",
|
||||
TargetType: "user",
|
||||
TargetID: "555555555555555555",
|
||||
TargetName: "Alice",
|
||||
EventAt: at.Format(time.RFC3339Nano),
|
||||
}}, mentions)
|
||||
|
||||
require.Equal(t, []string{"Deploy", "Ready"}, embedText(raw))
|
||||
require.Equal(t, "helloworld\ntrace.txt\natt2\nDeploy\nReady", normalizeText(raw["content"], attachmentText(attachments), embedText(raw)))
|
||||
require.Equal(t, "hidden text", cleanText("\u200bhidden\x00 text\n"))
|
||||
require.Equal(t, "222222222222222222", messageReferenceID(raw))
|
||||
require.Empty(t, messageReferenceID(map[string]any{}))
|
||||
|
||||
require.Contains(t, syntheticGuild("g1", "Guild").RawJSON, "discord_desktop")
|
||||
require.Equal(t, "dm", syntheticChannel("c1", DirectMessageGuildID, "Alice").Kind)
|
||||
require.Equal(t, "group_dm", syntheticChannel("c2", DirectMessageGuildID, "Alice, Bob").Kind)
|
||||
require.Equal(t, "channel-123456", syntheticChannel("123456123456", "g1", "").Name)
|
||||
require.Contains(t, channelRawJSON(raw, "c1", "g1", "general", "text"), `"kind":"text"`)
|
||||
require.Contains(t, messageRawJSON(raw, "333333333333333333", "999999999999999999", "111111111111111111", "444444444444444444"), "desktop_cache_note")
|
||||
require.Equal(t, "Alice, Bob", recipientLabel([]any{
|
||||
map[string]any{"username": "Bob"},
|
||||
map[string]any{"global_name": "Alice"},
|
||||
map[string]any{},
|
||||
}))
|
||||
|
||||
require.True(t, parseDiscordTime("2026-05-08T12:00:00.123Z").Equal(time.Date(2026, 5, 8, 12, 0, 0, 123000000, time.UTC)))
|
||||
require.True(t, parseDiscordTime("bad").IsZero())
|
||||
require.True(t, parseDiscordTime("").IsZero())
|
||||
require.False(t, snowflakeTime("175928847299117063").IsZero())
|
||||
require.True(t, snowflakeTime("bad").IsZero())
|
||||
require.Empty(t, formatOptionalTime(time.Time{}))
|
||||
require.Equal(t, "2026-05-08T12:00:00Z", formatOptionalTime(at))
|
||||
require.True(t, looksSnowflake("123456789012"))
|
||||
require.False(t, looksSnowflake("123"))
|
||||
require.False(t, looksSnowflake("12345678901x"))
|
||||
}
|
||||
|
||||
@ -1,91 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ollamaProvider struct {
|
||||
client *http.Client
|
||||
baseURL string
|
||||
model string
|
||||
maxInputChars int
|
||||
}
|
||||
|
||||
type ollamaEmbedRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type ollamaEmbedResponse struct {
|
||||
Model string `json:"model"`
|
||||
Embeddings [][]float32 `json:"embeddings"`
|
||||
}
|
||||
|
||||
func newOllamaProvider(settings providerSettings) Provider {
|
||||
return &ollamaProvider{
|
||||
client: settings.HTTPClient,
|
||||
baseURL: settings.BaseURL,
|
||||
model: settings.Model,
|
||||
maxInputChars: settings.MaxInputChars,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ollamaProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
|
||||
if len(inputs) == 0 {
|
||||
return EmbeddingBatch{Model: p.model}, nil
|
||||
}
|
||||
payload := ollamaEmbedRequest{
|
||||
Model: p.model,
|
||||
Input: trimInputs(inputs, p.maxInputChars),
|
||||
}
|
||||
var response ollamaEmbedResponse
|
||||
if err := postJSON(ctx, p.client, p.baseURL+"/api/embed", "", payload, &response); err != nil {
|
||||
return EmbeddingBatch{}, err
|
||||
}
|
||||
if len(response.Embeddings) != len(inputs) {
|
||||
return EmbeddingBatch{}, fmt.Errorf("ollama embedding response returned %d vectors for %d inputs", len(response.Embeddings), len(inputs))
|
||||
}
|
||||
dimensions, err := inferDimensions(response.Embeddings)
|
||||
if err != nil {
|
||||
return EmbeddingBatch{}, err
|
||||
}
|
||||
model := response.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: response.Embeddings}, nil
|
||||
}
|
||||
|
||||
func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, payload any, target any) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal embedding request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build embedding request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("embedding request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
|
||||
return fmt.Errorf("decode embedding response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -1,82 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type openAICompatibleProvider struct {
|
||||
client *http.Client
|
||||
baseURL string
|
||||
apiKey string
|
||||
model string
|
||||
maxInputChars int
|
||||
}
|
||||
|
||||
type openAIEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type openAIEmbeddingResponse struct {
|
||||
Model string `json:"model"`
|
||||
Data []openAIEmbeddingItem `json:"data"`
|
||||
}
|
||||
|
||||
type openAIEmbeddingItem struct {
|
||||
Index *int `json:"index"`
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
func newOpenAICompatibleProvider(settings providerSettings) Provider {
|
||||
return &openAICompatibleProvider{
|
||||
client: settings.HTTPClient,
|
||||
baseURL: settings.BaseURL,
|
||||
apiKey: settings.APIKey,
|
||||
model: settings.Model,
|
||||
maxInputChars: settings.MaxInputChars,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *openAICompatibleProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
|
||||
if len(inputs) == 0 {
|
||||
return EmbeddingBatch{Model: p.model}, nil
|
||||
}
|
||||
payload := openAIEmbeddingRequest{
|
||||
Model: p.model,
|
||||
Input: trimInputs(inputs, p.maxInputChars),
|
||||
}
|
||||
var response openAIEmbeddingResponse
|
||||
if err := postJSON(ctx, p.client, p.baseURL+"/embeddings", p.apiKey, payload, &response); err != nil {
|
||||
return EmbeddingBatch{}, err
|
||||
}
|
||||
if len(response.Data) != len(inputs) {
|
||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response returned %d vectors for %d inputs", len(response.Data), len(inputs))
|
||||
}
|
||||
vectors := make([][]float32, len(inputs))
|
||||
seen := make([]bool, len(inputs))
|
||||
for position, item := range response.Data {
|
||||
index := position
|
||||
if item.Index != nil {
|
||||
index = *item.Index
|
||||
}
|
||||
if index < 0 || index >= len(inputs) {
|
||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response index %d out of range", index)
|
||||
}
|
||||
if seen[index] {
|
||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response duplicated index %d", index)
|
||||
}
|
||||
seen[index] = true
|
||||
vectors[index] = item.Embedding
|
||||
}
|
||||
dimensions, err := inferDimensions(vectors)
|
||||
if err != nil {
|
||||
return EmbeddingBatch{}, err
|
||||
}
|
||||
model := response.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: vectors}, nil
|
||||
}
|
||||
@ -1,310 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderOpenAI = "openai"
|
||||
ProviderOllama = "ollama"
|
||||
ProviderLlamaCpp = "llamacpp"
|
||||
ProviderOpenAICompatible = "openai_compatible"
|
||||
DefaultOpenAIBaseURL = "https://api.openai.com/v1"
|
||||
DefaultOllamaBaseURL = "http://127.0.0.1:11434"
|
||||
DefaultLlamaCppBaseURL = "http://127.0.0.1:8080/v1"
|
||||
DefaultOpenAIModel = "text-embedding-3-small"
|
||||
DefaultLocalEmbeddingModel = "nomic-embed-text"
|
||||
DefaultBatchSize = 64
|
||||
DefaultMaxInputChars = 12000
|
||||
DefaultRequestTimeout = 2 * time.Minute
|
||||
DefaultProbeTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// Provider is the narrow embedding surface used by later queue/search work.
|
||||
type Provider interface {
|
||||
Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error)
|
||||
}
|
||||
|
||||
type EmbeddingBatch struct {
|
||||
Model string
|
||||
Dimensions int
|
||||
Vectors [][]float32
|
||||
}
|
||||
|
||||
type HTTPError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *HTTPError) Error() string {
|
||||
return fmt.Sprintf("embedding request failed with HTTP %d: %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
func IsRateLimitError(err error) bool {
|
||||
var httpErr *HTTPError
|
||||
return errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
type CheckResult struct {
|
||||
Provider string
|
||||
Model string
|
||||
BaseURL string
|
||||
Status string
|
||||
Warning string
|
||||
Probed bool
|
||||
}
|
||||
|
||||
type Option func(*providerOptions)
|
||||
|
||||
type providerOptions struct {
|
||||
httpClient *http.Client
|
||||
timeoutOverride time.Duration
|
||||
}
|
||||
|
||||
type providerSettings struct {
|
||||
Name string
|
||||
Model string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
MaxInputChars int
|
||||
Timeout time.Duration
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func WithHTTPClient(client *http.Client) Option {
|
||||
return func(opts *providerOptions) {
|
||||
opts.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(opts *providerOptions) {
|
||||
opts.timeoutOverride = timeout
|
||||
}
|
||||
}
|
||||
|
||||
func NewProvider(cfg config.EmbeddingsConfig, opts ...Option) (Provider, error) {
|
||||
settings, err := resolveProviderConfig(cfg, true, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newProvider(settings)
|
||||
}
|
||||
|
||||
func CheckProvider(ctx context.Context, cfg config.EmbeddingsConfig) CheckResult {
|
||||
settings, err := resolveProviderConfig(cfg, true, WithRequestTimeout(DefaultProbeTimeout))
|
||||
if err != nil {
|
||||
return CheckResult{
|
||||
Provider: normalizedProviderName(cfg.Provider),
|
||||
Model: strings.TrimSpace(cfg.Model),
|
||||
BaseURL: strings.TrimSpace(cfg.BaseURL),
|
||||
Status: "warning",
|
||||
Warning: err.Error(),
|
||||
}
|
||||
}
|
||||
result := CheckResult{
|
||||
Provider: settings.Name,
|
||||
Model: settings.Model,
|
||||
BaseURL: settings.BaseURL,
|
||||
Status: "ok",
|
||||
}
|
||||
if !shouldProbe(settings) {
|
||||
return result
|
||||
}
|
||||
provider, err := newProvider(settings)
|
||||
if err != nil {
|
||||
result.Status = "warning"
|
||||
result.Warning = err.Error()
|
||||
return result
|
||||
}
|
||||
probeCtx, cancel := context.WithTimeout(ctx, DefaultProbeTimeout)
|
||||
defer cancel()
|
||||
if _, err := provider.Embed(probeCtx, []string{"discrawl probe"}); err != nil {
|
||||
result.Status = "warning"
|
||||
result.Warning = err.Error()
|
||||
return result
|
||||
}
|
||||
result.Probed = true
|
||||
return result
|
||||
}
|
||||
|
||||
func resolveProviderConfig(cfg config.EmbeddingsConfig, validateAPIKey bool, opts ...Option) (providerSettings, error) {
|
||||
options := providerOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
name := normalizedProviderName(cfg.Provider)
|
||||
if name == "" {
|
||||
name = ProviderOpenAI
|
||||
}
|
||||
model := strings.TrimSpace(cfg.Model)
|
||||
if model == "" {
|
||||
model = defaultModel(name)
|
||||
}
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
switch name {
|
||||
case ProviderOpenAI:
|
||||
baseURL = DefaultOpenAIBaseURL
|
||||
case ProviderOllama:
|
||||
baseURL = DefaultOllamaBaseURL
|
||||
case ProviderLlamaCpp:
|
||||
baseURL = DefaultLlamaCppBaseURL
|
||||
case ProviderOpenAICompatible:
|
||||
return providerSettings{}, fmt.Errorf("embedding provider %q requires base_url", name)
|
||||
}
|
||||
}
|
||||
timeout := DefaultRequestTimeout
|
||||
if strings.TrimSpace(cfg.RequestTimeout) != "" {
|
||||
parsed, err := time.ParseDuration(cfg.RequestTimeout)
|
||||
if err != nil {
|
||||
return providerSettings{}, fmt.Errorf("parse embeddings request_timeout: %w", err)
|
||||
}
|
||||
if parsed <= 0 {
|
||||
return providerSettings{}, errors.New("embeddings request_timeout must be positive")
|
||||
}
|
||||
timeout = parsed
|
||||
}
|
||||
if options.timeoutOverride > 0 && options.timeoutOverride < timeout {
|
||||
timeout = options.timeoutOverride
|
||||
}
|
||||
maxInputChars := cfg.MaxInputChars
|
||||
if maxInputChars <= 0 {
|
||||
maxInputChars = DefaultMaxInputChars
|
||||
}
|
||||
switch name {
|
||||
case ProviderOpenAI, ProviderOllama, ProviderLlamaCpp, ProviderOpenAICompatible:
|
||||
default:
|
||||
return providerSettings{}, fmt.Errorf("unsupported embedding provider %q", name)
|
||||
}
|
||||
apiKey, err := resolveAPIKey(name, cfg.APIKeyEnv, validateAPIKey)
|
||||
if err != nil {
|
||||
return providerSettings{}, err
|
||||
}
|
||||
client := options.httpClient
|
||||
if client == nil {
|
||||
client = &http.Client{Timeout: timeout}
|
||||
}
|
||||
if _, err := url.ParseRequestURI(baseURL); err != nil {
|
||||
return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: %w", baseURL, err)
|
||||
}
|
||||
return providerSettings{
|
||||
Name: name,
|
||||
Model: model,
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
MaxInputChars: maxInputChars,
|
||||
Timeout: timeout,
|
||||
HTTPClient: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newProvider(settings providerSettings) (Provider, error) {
|
||||
switch settings.Name {
|
||||
case ProviderOllama:
|
||||
return newOllamaProvider(settings), nil
|
||||
case ProviderOpenAI, ProviderLlamaCpp, ProviderOpenAICompatible:
|
||||
return newOpenAICompatibleProvider(settings), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported embedding provider %q", settings.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAPIKey(provider, apiKeyEnv string, validate bool) (string, error) {
|
||||
envName := strings.TrimSpace(apiKeyEnv)
|
||||
required := provider == ProviderOpenAI
|
||||
if envName == "" {
|
||||
if required {
|
||||
envName = "OPENAI_API_KEY"
|
||||
} else {
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
value := strings.TrimSpace(os.Getenv(envName))
|
||||
if value == "" {
|
||||
if required || validate {
|
||||
return "", fmt.Errorf("embedding provider %q requires API key env %s", provider, envName)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func normalizedProviderName(provider string) string {
|
||||
return strings.ToLower(strings.TrimSpace(provider))
|
||||
}
|
||||
|
||||
func defaultModel(provider string) string {
|
||||
switch provider {
|
||||
case ProviderOllama, ProviderLlamaCpp:
|
||||
return DefaultLocalEmbeddingModel
|
||||
default:
|
||||
return DefaultOpenAIModel
|
||||
}
|
||||
}
|
||||
|
||||
func shouldProbe(settings providerSettings) bool {
|
||||
switch settings.Name {
|
||||
case ProviderOllama, ProviderLlamaCpp:
|
||||
return true
|
||||
case ProviderOpenAICompatible:
|
||||
return isLoopbackBaseURL(settings.BaseURL)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isLoopbackBaseURL(rawURL string) bool {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := parsed.Hostname()
|
||||
if host == "localhost" {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func trimInputs(inputs []string, maxChars int) []string {
|
||||
if maxChars <= 0 {
|
||||
maxChars = DefaultMaxInputChars
|
||||
}
|
||||
out := make([]string, len(inputs))
|
||||
for i, input := range inputs {
|
||||
runes := []rune(input)
|
||||
if len(runes) > maxChars {
|
||||
runes = runes[:maxChars]
|
||||
}
|
||||
out[i] = string(runes)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func inferDimensions(vectors [][]float32) (int, error) {
|
||||
dimensions := 0
|
||||
for _, vector := range vectors {
|
||||
if len(vector) == 0 {
|
||||
return 0, errors.New("embedding response contained an empty vector")
|
||||
}
|
||||
if dimensions == 0 {
|
||||
dimensions = len(vector)
|
||||
continue
|
||||
}
|
||||
if len(vector) != dimensions {
|
||||
return 0, fmt.Errorf("embedding response dimensions mismatch: got %d want %d", len(vector), dimensions)
|
||||
}
|
||||
}
|
||||
return dimensions, nil
|
||||
}
|
||||
@ -1,319 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/config"
|
||||
)
|
||||
|
||||
func TestOllamaProviderEmbeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
var req ollamaEmbedRequest
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, "nomic-embed-text", req.Model)
|
||||
assert.Equal(t, []string{"abcd", "xy"}, req.Input)
|
||||
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
||||
Provider: ProviderOllama,
|
||||
Model: "nomic-embed-text",
|
||||
BaseURL: server.URL,
|
||||
MaxInputChars: 4,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
batch, err := provider.Embed(context.Background(), []string{"abcdef", "xy"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nomic-embed-text", batch.Model)
|
||||
require.Equal(t, 3, batch.Dimensions)
|
||||
require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, batch.Vectors)
|
||||
}
|
||||
|
||||
func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/embeddings", r.URL.Path)
|
||||
assert.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
|
||||
var req openAIEmbeddingRequest
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, "local-model", req.Model)
|
||||
assert.Equal(t, []string{"one", "two"}, req.Input)
|
||||
_, _ = w.Write([]byte(`{
|
||||
"model":"local-model",
|
||||
"data":[
|
||||
{"index":1,"embedding":[3,4]},
|
||||
{"index":0,"embedding":[1,2]}
|
||||
]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
t.Setenv("DISCRAWL_EMBED_KEY", "secret")
|
||||
|
||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "local-model",
|
||||
BaseURL: server.URL,
|
||||
APIKeyEnv: "DISCRAWL_EMBED_KEY",
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
batch, err := provider.Embed(context.Background(), []string{"one", "two"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "local-model", batch.Model)
|
||||
require.Equal(t, 2, batch.Dimensions)
|
||||
require.Equal(t, [][]float32{{1, 2}, {3, 4}}, batch.Vectors)
|
||||
}
|
||||
|
||||
func TestProviderFactoryDefaultsAndValidation(t *testing.T) {
|
||||
t.Setenv("OPENAI_API_KEY", "openai-secret")
|
||||
|
||||
openAI, err := resolveProviderConfig(config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAI,
|
||||
RequestTimeout: "5s",
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL)
|
||||
require.Equal(t, DefaultOpenAIModel, openAI.Model)
|
||||
require.Equal(t, "openai-secret", openAI.APIKey)
|
||||
|
||||
ollama, err := resolveProviderConfig(config.EmbeddingsConfig{
|
||||
Provider: ProviderOllama,
|
||||
RequestTimeout: "5s",
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL)
|
||||
require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model)
|
||||
|
||||
llamaCpp, err := resolveProviderConfig(config.EmbeddingsConfig{
|
||||
Provider: ProviderLlamaCpp,
|
||||
RequestTimeout: "5s",
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL)
|
||||
|
||||
_, err = resolveProviderConfig(config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
RequestTimeout: "5s",
|
||||
}, true)
|
||||
require.ErrorContains(t, err, "requires base_url")
|
||||
}
|
||||
|
||||
func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) {
|
||||
t.Setenv("OPENAI_API_KEY", "")
|
||||
|
||||
_, err := NewProvider(config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAI,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY")
|
||||
}
|
||||
|
||||
func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) {
|
||||
t.Setenv("MISSING_EMBED_KEY", "")
|
||||
|
||||
_, err := NewProvider(config.EmbeddingsConfig{
|
||||
Provider: "bogus",
|
||||
APIKeyEnv: "MISSING_EMBED_KEY",
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"")
|
||||
}
|
||||
|
||||
func TestCheckProviderProbesLocalProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
||||
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
result := CheckProvider(context.Background(), config.EmbeddingsConfig{
|
||||
Provider: ProviderOllama,
|
||||
Model: "nomic-embed-text",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.Equal(t, "ok", result.Status)
|
||||
require.True(t, result.Probed)
|
||||
require.Empty(t, result.Warning)
|
||||
require.Equal(t, server.URL, result.BaseURL)
|
||||
}
|
||||
|
||||
func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "not ready", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
result := CheckProvider(context.Background(), config.EmbeddingsConfig{
|
||||
Provider: ProviderOllama,
|
||||
Model: "nomic-embed-text",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.Equal(t, "warning", result.Status)
|
||||
require.Contains(t, result.Warning, "HTTP 503")
|
||||
require.False(t, result.Probed)
|
||||
}
|
||||
|
||||
func TestProviderExposesRateLimitErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "local-model",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = provider.Embed(context.Background(), []string{"one"})
|
||||
require.ErrorContains(t, err, "HTTP 429")
|
||||
require.True(t, IsRateLimitError(err))
|
||||
}
|
||||
|
||||
func TestProviderRejectsInvalidResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "local-model",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = provider.Embed(context.Background(), []string{"one", "two"})
|
||||
require.ErrorContains(t, err, "dimensions mismatch")
|
||||
}
|
||||
|
||||
func TestEmbeddingProvidersHandleEmptyInputsAndIndexErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
settings := providerSettings{
|
||||
Name: ProviderOllama,
|
||||
Model: "model",
|
||||
BaseURL: "http://127.0.0.1:1",
|
||||
MaxInputChars: 10,
|
||||
HTTPClient: http.DefaultClient,
|
||||
}
|
||||
ollama := newOllamaProvider(settings)
|
||||
batch, err := ollama.Embed(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "model", batch.Model)
|
||||
|
||||
settings.Name = ProviderOpenAICompatible
|
||||
openai := newOpenAICompatibleProvider(settings)
|
||||
batch, err = openai.Embed(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "model", batch.Model)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
inputs []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "count",
|
||||
body: `{"data":[]}`,
|
||||
inputs: []string{"one"},
|
||||
want: "returned 0 vectors for 1 inputs",
|
||||
},
|
||||
{
|
||||
name: "range",
|
||||
body: `{"data":[{"index":2,"embedding":[1]}]}`,
|
||||
inputs: []string{"one"},
|
||||
want: "index 2 out of range",
|
||||
},
|
||||
{
|
||||
name: "duplicate",
|
||||
body: `{"data":[{"index":0,"embedding":[1]},{"index":0,"embedding":[2]}]}`,
|
||||
inputs: []string{"one", "two"},
|
||||
want: "duplicated index 0",
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(tc.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "model",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = provider.Embed(context.Background(), tc.inputs)
|
||||
require.ErrorContains(t, err, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderOptionsAndProbeDecisions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &http.Client{Timeout: time.Second}
|
||||
settings, err := resolveProviderConfig(config.EmbeddingsConfig{
|
||||
Provider: ProviderOllama,
|
||||
BaseURL: "http://127.0.0.1:11434/",
|
||||
RequestTimeout: "30s",
|
||||
}, true, WithHTTPClient(client), WithRequestTimeout(50*time.Millisecond))
|
||||
require.NoError(t, err)
|
||||
require.Same(t, client, settings.HTTPClient)
|
||||
require.Equal(t, 50*time.Millisecond, settings.Timeout)
|
||||
require.Equal(t, "http://127.0.0.1:11434", settings.BaseURL)
|
||||
require.True(t, shouldProbe(settings))
|
||||
|
||||
require.True(t, isLoopbackBaseURL("http://localhost:8080/v1"))
|
||||
require.True(t, isLoopbackBaseURL("http://[::1]:8080/v1"))
|
||||
require.False(t, isLoopbackBaseURL("https://api.example.com/v1"))
|
||||
require.False(t, isLoopbackBaseURL("://bad"))
|
||||
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAI}))
|
||||
require.True(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "http://localhost:8080/v1"}))
|
||||
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "https://api.example.com/v1"}))
|
||||
}
|
||||
|
||||
func TestCheckProviderSkipsRemoteCompatibleProbe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := CheckProvider(context.Background(), config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "remote-model",
|
||||
BaseURL: "https://api.example.com/v1",
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.Equal(t, "ok", result.Status)
|
||||
require.False(t, result.Probed)
|
||||
require.Empty(t, result.Warning)
|
||||
}
|
||||
@ -8,23 +8,26 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/openclaw/crawlkit/mirror"
|
||||
"github.com/openclaw/crawlkit/snapshot"
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
"github.com/vincentkoc/crawlkit/mirror"
|
||||
"github.com/vincentkoc/crawlkit/snapshot"
|
||||
)
|
||||
|
||||
const (
|
||||
ManifestName = "manifest.json"
|
||||
LastImportSyncScope = "share:last_import_at"
|
||||
LastImportManifestSyncScope = "share:last_import_manifest_generated_at"
|
||||
LastImportManifestJSONScope = "share:last_import_manifest_json"
|
||||
directMessageGuildID = "@me"
|
||||
)
|
||||
|
||||
@ -159,6 +162,7 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
manifest = enrichManifestFromGit(ctx, opts.RepoPath, "HEAD", manifest)
|
||||
opts.reportProgress(ImportProgress{Phase: "start", TotalRows: manifestRowCount(manifest)})
|
||||
restorePragmas, err := applyImportPragmas(ctx, s.DB())
|
||||
if err != nil {
|
||||
@ -262,6 +266,7 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
|
||||
if err != nil {
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
manifest = enrichManifestFromGit(ctx, opts.RepoPath, "HEAD", manifest)
|
||||
if ManifestAlreadyImported(ctx, s, manifest) {
|
||||
if opts.IncludeEmbeddings {
|
||||
if err := ImportEmbeddings(ctx, s, opts, manifest); err != nil {
|
||||
@ -273,6 +278,12 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
|
||||
}
|
||||
return manifest, false, nil
|
||||
}
|
||||
if previous, ok := PreviousImportedManifest(ctx, s, opts); ok {
|
||||
imported, changed, err := ImportIncremental(ctx, s, opts, previous, manifest)
|
||||
if err == nil || !errors.Is(err, errIncrementalUnsupported) {
|
||||
return imported, changed, err
|
||||
}
|
||||
}
|
||||
imported, err := Import(ctx, s, opts)
|
||||
if err != nil {
|
||||
return Manifest{}, false, err
|
||||
@ -280,6 +291,81 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
|
||||
return imported, true, nil
|
||||
}
|
||||
|
||||
var errIncrementalUnsupported = errors.New("incremental share import unsupported")
|
||||
|
||||
func ImportIncremental(ctx context.Context, s *store.Store, opts Options, previous, manifest Manifest) (Manifest, bool, error) {
|
||||
plan := snapshot.PlanIncrementalImport(snapshotManifest(previous), snapshotManifest(manifest))
|
||||
plan, supported := shareIncrementalPlan(plan)
|
||||
if !supported {
|
||||
return Manifest{}, false, errIncrementalUnsupported
|
||||
}
|
||||
if !plan.Changed() {
|
||||
if err := MarkImported(ctx, s, manifest); err != nil {
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
return manifest, false, nil
|
||||
}
|
||||
opts.reportProgress(ImportProgress{Phase: "start", TotalRows: manifestRowCount(manifest)})
|
||||
restorePragmas, err := applyImportPragmas(ctx, s.DB())
|
||||
if err != nil {
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
pragmasRestored := false
|
||||
defer func() {
|
||||
if !pragmasRestored {
|
||||
_ = restorePragmas(ctx)
|
||||
}
|
||||
}()
|
||||
if _, _, err := snapshot.ImportIncremental(ctx, snapshot.IncrementalImportOptions{
|
||||
DB: s.DB(),
|
||||
RootDir: opts.RepoPath,
|
||||
Current: snapshotManifest(manifest),
|
||||
Plan: plan,
|
||||
Progress: func(progress snapshot.ImportProgress) {
|
||||
opts.reportProgress(ImportProgress{
|
||||
Phase: progress.Phase,
|
||||
Table: progress.Table,
|
||||
File: progress.File,
|
||||
FileIndex: progress.FileIndex,
|
||||
FileCount: progress.FileCount,
|
||||
Rows: progress.Rows,
|
||||
TotalRows: progress.TotalRows,
|
||||
})
|
||||
},
|
||||
Filter: func(table string, row map[string]any) (bool, error) {
|
||||
return !isDirectMessageSnapshotRow(table, row), nil
|
||||
},
|
||||
DeleteTable: func(ctx context.Context, tx *sql.Tx, table string) error {
|
||||
query, args := snapshotDeleteQuery(table)
|
||||
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
||||
return fmt.Errorf("clear %s: %w", table, err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ImportRow: importIncrementalSnapshotRow,
|
||||
AfterImport: func(ctx context.Context, tx *sql.Tx) error {
|
||||
if err := repairImportedGuildIDs(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if opts.IncludeEmbeddings {
|
||||
return importEmbeddings(ctx, tx, opts, manifest.Embeddings)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}); err != nil {
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
if err := MarkImported(ctx, s, manifest); err != nil {
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
if err := restorePragmas(ctx); err != nil {
|
||||
return Manifest{}, false, err
|
||||
}
|
||||
pragmasRestored = true
|
||||
opts.reportProgress(ImportProgress{Phase: "done", TotalRows: manifestRowCount(manifest)})
|
||||
return manifest, true, nil
|
||||
}
|
||||
|
||||
func (opts Options) reportProgress(progress ImportProgress) {
|
||||
if opts.Progress != nil {
|
||||
opts.Progress(progress)
|
||||
@ -340,7 +426,173 @@ func MarkImported(ctx context.Context, s *store.Store, manifest Manifest) error
|
||||
if manifest.GeneratedAt.IsZero() {
|
||||
return nil
|
||||
}
|
||||
return s.SetSyncState(ctx, LastImportManifestSyncScope, manifest.GeneratedAt.Format(time.RFC3339Nano))
|
||||
if err := s.SetSyncState(ctx, LastImportManifestSyncScope, manifest.GeneratedAt.Format(time.RFC3339Nano)); err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal imported manifest state: %w", err)
|
||||
}
|
||||
return s.SetSyncState(ctx, LastImportManifestJSONScope, string(body))
|
||||
}
|
||||
|
||||
func PreviousImportedManifest(ctx context.Context, s *store.Store, opts Options) (Manifest, bool) {
|
||||
body, err := s.GetSyncState(ctx, LastImportManifestJSONScope)
|
||||
if err == nil && strings.TrimSpace(body) != "" {
|
||||
var manifest Manifest
|
||||
if json.Unmarshal([]byte(body), &manifest) == nil && !manifest.GeneratedAt.IsZero() {
|
||||
return manifest, true
|
||||
}
|
||||
}
|
||||
last, err := s.GetSyncState(ctx, LastImportManifestSyncScope)
|
||||
if err != nil || strings.TrimSpace(last) == "" {
|
||||
return Manifest{}, false
|
||||
}
|
||||
generatedAt, err := time.Parse(time.RFC3339Nano, last)
|
||||
if err != nil {
|
||||
return Manifest{}, false
|
||||
}
|
||||
manifest, err := manifestFromGitHistory(ctx, opts.RepoPath, generatedAt)
|
||||
if err != nil {
|
||||
return Manifest{}, false
|
||||
}
|
||||
return manifest, true
|
||||
}
|
||||
|
||||
func manifestFromGitHistory(ctx context.Context, repoPath string, generatedAt time.Time) (Manifest, error) {
|
||||
out, err := output(ctx, repoPath, "git", "log", "--format=%H", "--max-count=500", "--", ManifestName)
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
for hash := range strings.FieldsSeq(out) {
|
||||
body, err := output(ctx, repoPath, "git", "show", hash+":"+ManifestName)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal([]byte(body), &manifest); err != nil {
|
||||
continue
|
||||
}
|
||||
if manifest.GeneratedAt.Equal(generatedAt) {
|
||||
return enrichManifestFromGit(ctx, repoPath, hash, manifest), nil
|
||||
}
|
||||
}
|
||||
return Manifest{}, fmt.Errorf("imported manifest %s not found in git history", generatedAt.Format(time.RFC3339Nano))
|
||||
}
|
||||
|
||||
func enrichManifestFromGit(ctx context.Context, repoPath, rev string, manifest Manifest) Manifest {
|
||||
if strings.TrimSpace(repoPath) == "" || manifestHasFileManifests(manifest) {
|
||||
return manifest
|
||||
}
|
||||
files, err := gitTreeFiles(ctx, repoPath, rev)
|
||||
if err != nil {
|
||||
return manifest
|
||||
}
|
||||
for i := range manifest.Tables {
|
||||
table := &manifest.Tables[i]
|
||||
if len(table.FileManifests) > 0 {
|
||||
continue
|
||||
}
|
||||
paths := table.Files
|
||||
if len(paths) == 0 && strings.TrimSpace(table.File) != "" {
|
||||
paths = []string{table.File}
|
||||
}
|
||||
table.FileManifests = make([]snapshot.FileManifest, 0, len(paths))
|
||||
for _, path := range paths {
|
||||
info, ok := files[path]
|
||||
if !ok {
|
||||
table.FileManifests = nil
|
||||
break
|
||||
}
|
||||
rows := 0
|
||||
if len(paths) == 1 {
|
||||
rows = table.Rows
|
||||
}
|
||||
table.FileManifests = append(table.FileManifests, snapshot.FileManifest{
|
||||
Path: path,
|
||||
Rows: rows,
|
||||
Size: info.size,
|
||||
SHA256: "git:" + info.object,
|
||||
})
|
||||
}
|
||||
}
|
||||
return manifest
|
||||
}
|
||||
|
||||
func manifestHasFileManifests(manifest Manifest) bool {
|
||||
for _, table := range manifest.Tables {
|
||||
if (len(table.Files) > 0 || strings.TrimSpace(table.File) != "") && len(table.FileManifests) == 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type gitTreeFile struct {
|
||||
object string
|
||||
size int64
|
||||
}
|
||||
|
||||
func gitTreeFiles(ctx context.Context, repoPath, rev string) (map[string]gitTreeFile, error) {
|
||||
if strings.TrimSpace(rev) == "" {
|
||||
rev = "HEAD"
|
||||
}
|
||||
out, err := output(ctx, repoPath, "git", "ls-tree", "-r", "-l", rev, "--", "tables")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
files := map[string]gitTreeFile{}
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 5 {
|
||||
continue
|
||||
}
|
||||
size, _ := strconv.ParseInt(fields[3], 10, 64)
|
||||
files[fields[4]] = gitTreeFile{object: fields[2], size: size}
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func snapshotManifest(manifest Manifest) snapshot.Manifest {
|
||||
return snapshot.Manifest{
|
||||
Version: manifest.Version,
|
||||
GeneratedAt: manifest.GeneratedAt,
|
||||
Tables: manifest.Tables,
|
||||
Files: manifest.Files,
|
||||
}
|
||||
}
|
||||
|
||||
func shareIncrementalPlan(plan snapshot.ImportPlan) (snapshot.ImportPlan, bool) {
|
||||
if plan.Full {
|
||||
return plan, false
|
||||
}
|
||||
out := snapshot.ImportPlan{Tables: make([]snapshot.TableImportPlan, 0, len(plan.Tables))}
|
||||
for _, tablePlan := range plan.Tables {
|
||||
switch tablePlan.Mode {
|
||||
case snapshot.TableImportSkip:
|
||||
out.Tables = append(out.Tables, tablePlan)
|
||||
case snapshot.TableImportFiles:
|
||||
switch tablePlan.Table.Name {
|
||||
case "messages":
|
||||
out.Tables = append(out.Tables, tablePlan)
|
||||
case "sync_state":
|
||||
tablePlan.Mode = snapshot.TableImportReplace
|
||||
tablePlan.Files = nil
|
||||
tablePlan.Reason = "replace sync_state to avoid stale cursors"
|
||||
out.Tables = append(out.Tables, tablePlan)
|
||||
default:
|
||||
return plan, false
|
||||
}
|
||||
case snapshot.TableImportReplace:
|
||||
if tablePlan.Table.Name != "sync_state" {
|
||||
return plan, false
|
||||
}
|
||||
out.Tables = append(out.Tables, tablePlan)
|
||||
default:
|
||||
return plan, false
|
||||
}
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func ReadManifest(repoPath string) (Manifest, error) {
|
||||
@ -874,6 +1126,112 @@ func importValue(value any) any {
|
||||
}
|
||||
}
|
||||
|
||||
func importIncrementalSnapshotRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error {
|
||||
if table == "message_events" || table == "mention_events" {
|
||||
delete(row, "event_id")
|
||||
}
|
||||
if err := insertOrReplaceSnapshotRow(ctx, tx, table, row); err != nil {
|
||||
return err
|
||||
}
|
||||
if table != "messages" {
|
||||
return nil
|
||||
}
|
||||
messageID := stringValue(row["id"])
|
||||
if messageID == "" {
|
||||
return nil
|
||||
}
|
||||
return upsertMessageFTSRow(ctx, tx, messageID)
|
||||
}
|
||||
|
||||
func insertOrReplaceSnapshotRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error {
|
||||
cols := make([]string, 0, len(row))
|
||||
for col := range row {
|
||||
cols = append(cols, col)
|
||||
}
|
||||
sort.Strings(cols)
|
||||
quoted := make([]string, 0, len(cols))
|
||||
placeholders := make([]string, 0, len(cols))
|
||||
args := make([]any, 0, len(cols))
|
||||
for _, col := range cols {
|
||||
quoted = append(quoted, quoteIdent(col))
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, importValue(row[col]))
|
||||
}
|
||||
stmt := "insert or replace into " + quoteIdent(table) + "(" + strings.Join(quoted, ",") + ") values(" + strings.Join(placeholders, ",") + ")"
|
||||
if _, err := tx.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return fmt.Errorf("insert %s: %w", table, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func upsertMessageFTSRow(ctx context.Context, tx *sql.Tx, messageID string) error {
|
||||
rowID, ok := messageFTSRowID(messageID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `delete from message_fts where rowid = ?`, rowID); err != nil {
|
||||
return fmt.Errorf("delete message_fts %s: %w", messageID, err)
|
||||
}
|
||||
var (
|
||||
guildID string
|
||||
channelID string
|
||||
authorID string
|
||||
authorName string
|
||||
channelName string
|
||||
content string
|
||||
)
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
select
|
||||
m.guild_id,
|
||||
m.channel_id,
|
||||
coalesce(m.author_id, ''),
|
||||
coalesce(
|
||||
json_extract(m.raw_json, '$.member.nick'),
|
||||
json_extract(m.raw_json, '$.author.global_name'),
|
||||
json_extract(m.raw_json, '$.author.username'),
|
||||
''
|
||||
),
|
||||
coalesce(c.name, ''),
|
||||
m.normalized_content
|
||||
from messages m
|
||||
left join channels c on c.id = m.channel_id
|
||||
where m.id = ?
|
||||
`, messageID).Scan(&guildID, &channelID, &authorID, &authorName, &channelName, &content); err != nil {
|
||||
return fmt.Errorf("query message_fts %s: %w", messageID, err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
insert into message_fts(rowid, message_id, guild_id, channel_id, author_id, author_name, channel_name, content)
|
||||
values(?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, rowID, messageID, guildID, channelID, nullIfEmpty(authorID), authorName, channelName, content); err != nil {
|
||||
return fmt.Errorf("insert message_fts %s: %w", messageID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func messageFTSRowID(messageID string) (int64, bool) {
|
||||
if messageID == "" {
|
||||
return 0, false
|
||||
}
|
||||
rowID, err := strconv.ParseInt(messageID, 10, 64)
|
||||
if err == nil && rowID > 0 {
|
||||
return rowID, true
|
||||
}
|
||||
hash := fnv.New64a()
|
||||
_, _ = hash.Write([]byte(messageID))
|
||||
rowID = int64(hash.Sum64() & ((uint64(1) << 63) - 1))
|
||||
if rowID == 0 {
|
||||
rowID = 1
|
||||
}
|
||||
return rowID, true
|
||||
}
|
||||
|
||||
func nullIfEmpty(value string) any {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func stringValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
|
||||
@ -14,6 +14,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openclaw/crawlkit/mirror"
|
||||
"github.com/openclaw/crawlkit/snapshot"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
@ -73,6 +75,127 @@ func TestExportImportRoundTrip(t *testing.T) {
|
||||
require.Equal(t, manifest.GeneratedAt, imported.GeneratedAt)
|
||||
}
|
||||
|
||||
func TestImportIfChangedUsesIncrementalTailImport(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
|
||||
defer func() { _ = src.Close() }()
|
||||
|
||||
repo := filepath.Join(t.TempDir(), "share")
|
||||
manifest, err := Export(ctx, src, Options{RepoPath: repo, Branch: "main"})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, tableEntry(t, manifest, "messages").FileManifests)
|
||||
|
||||
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = dst.Close() }()
|
||||
_, changed, err := ImportIfChanged(ctx, dst, Options{RepoPath: repo, Branch: "main"})
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
require.NoError(t, src.UpsertMessages(ctx, []store.MessageMutation{{
|
||||
Record: store.MessageRecord{
|
||||
ID: "m2",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "Peter",
|
||||
MessageType: 0,
|
||||
CreatedAt: now,
|
||||
Content: "delta landed fast",
|
||||
NormalizedContent: "delta landed fast",
|
||||
RawJSON: `{"author":{"username":"Peter"}}`,
|
||||
},
|
||||
}}))
|
||||
updated, err := Export(ctx, src, Options{RepoPath: repo, Branch: "main"})
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, manifest.GeneratedAt, updated.GeneratedAt)
|
||||
|
||||
var progress []ImportProgress
|
||||
imported, changed, err := ImportIfChanged(ctx, dst, Options{
|
||||
RepoPath: repo,
|
||||
Branch: "main",
|
||||
Progress: func(p ImportProgress) { progress = append(progress, p) },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, updated.GeneratedAt, imported.GeneratedAt)
|
||||
require.Contains(t, progressPhases(progress), "table_start")
|
||||
require.NotContains(t, progressPhases(progress), "rebuild_fts")
|
||||
|
||||
results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: "delta landed", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
require.Equal(t, "m2", results[0].MessageID)
|
||||
state, err := dst.GetSyncState(ctx, LastImportManifestJSONScope)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, state, `"file_manifests"`)
|
||||
}
|
||||
|
||||
func TestImportIfChangedInfersLegacyManifestFilesFromGit(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
|
||||
defer func() { _ = src.Close() }()
|
||||
|
||||
repo := filepath.Join(t.TempDir(), "share")
|
||||
require.NoError(t, exec.CommandContext(ctx, "git", "init", repo).Run())
|
||||
configureGitUser(t, repo)
|
||||
manifest, err := Export(ctx, src, Options{RepoPath: repo, Branch: "main"})
|
||||
require.NoError(t, err)
|
||||
writeShareManifest(t, repo, stripFileManifests(manifest))
|
||||
require.NoError(t, exec.CommandContext(ctx, "git", "-C", repo, "add", ".").Run())
|
||||
require.NoError(t, exec.CommandContext(ctx, "git", "-C", repo, "commit", "-m", "initial snapshot").Run())
|
||||
|
||||
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = dst.Close() }()
|
||||
_, changed, err := ImportIfChanged(ctx, dst, Options{RepoPath: repo, Branch: "main"})
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
require.NoError(t, src.UpsertMessages(ctx, []store.MessageMutation{{
|
||||
Record: store.MessageRecord{
|
||||
ID: "m2",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "Peter",
|
||||
MessageType: 0,
|
||||
CreatedAt: now,
|
||||
Content: "legacy git delta",
|
||||
NormalizedContent: "legacy git delta",
|
||||
RawJSON: `{"author":{"username":"Peter"}}`,
|
||||
},
|
||||
}}))
|
||||
updated, err := Export(ctx, src, Options{RepoPath: repo, Branch: "main"})
|
||||
require.NoError(t, err)
|
||||
writeShareManifest(t, repo, stripFileManifests(updated))
|
||||
require.NoError(t, exec.CommandContext(ctx, "git", "-C", repo, "add", ".").Run())
|
||||
require.NoError(t, exec.CommandContext(ctx, "git", "-C", repo, "commit", "-m", "tail snapshot").Run())
|
||||
|
||||
previous, ok := PreviousImportedManifest(ctx, dst, Options{RepoPath: repo, Branch: "main"})
|
||||
require.True(t, ok)
|
||||
planned, supported := shareIncrementalPlan(snapshot.PlanIncrementalImport(snapshotManifest(previous), snapshotManifest(enrichManifestFromGit(ctx, repo, "HEAD", stripFileManifests(updated)))))
|
||||
require.True(t, supported, "%+v", planned)
|
||||
require.True(t, planned.Changed(), "%+v", planned)
|
||||
|
||||
var progress []ImportProgress
|
||||
_, changed, err = ImportIfChanged(ctx, dst, Options{
|
||||
RepoPath: repo,
|
||||
Branch: "main",
|
||||
Progress: func(p ImportProgress) { progress = append(progress, p) },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.NotContains(t, progressPhases(progress), "rebuild_fts")
|
||||
results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: "legacy git delta", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
}
|
||||
|
||||
func TestApplyImportPragmasKeepCrashRecoveryEnabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := seedStore(t, filepath.Join(t.TempDir(), "dst.db"))
|
||||
@ -652,6 +775,10 @@ func TestShareSmallHelpersAndValidation(t *testing.T) {
|
||||
require.Equal(t, `insert into "messages"("id","weird""column") values(?,?)`, insertSQL("messages", []string{"id", `weird"column`}))
|
||||
require.Equal(t, "blob", exportValue([]byte("blob")))
|
||||
require.Equal(t, "plain", exportValue("plain"))
|
||||
require.Equal(t, int64(42), importValue(json.Number("42")))
|
||||
require.InDelta(t, 3.5, importValue(json.Number("3.5")), 0)
|
||||
require.Equal(t, "nope", importValue(json.Number("nope")))
|
||||
require.Equal(t, "plain", importValue("plain"))
|
||||
require.Equal(t, "plain", stringValue("plain"))
|
||||
require.Equal(t, "42", stringValue(json.Number("42")))
|
||||
require.Empty(t, stringValue(42))
|
||||
@ -662,6 +789,9 @@ func TestShareSmallHelpersAndValidation(t *testing.T) {
|
||||
query, args := snapshotExportQuery("messages")
|
||||
require.Equal(t, "select * from messages where guild_id != ?", query)
|
||||
require.Equal(t, []any{directMessageGuildID}, args)
|
||||
query, args = snapshotExportQuery("guilds")
|
||||
require.Equal(t, "select * from guilds where id != ?", query)
|
||||
require.Equal(t, []any{directMessageGuildID}, args)
|
||||
query, args = snapshotExportQuery("sync_state")
|
||||
require.Equal(t, "select * from sync_state where scope not like 'wiretap:%'", query)
|
||||
require.Nil(t, args)
|
||||
@ -672,9 +802,15 @@ func TestShareSmallHelpersAndValidation(t *testing.T) {
|
||||
query, args = snapshotDeleteQuery("channels")
|
||||
require.Equal(t, "delete from channels where guild_id != ?", query)
|
||||
require.Equal(t, []any{directMessageGuildID}, args)
|
||||
query, args = snapshotDeleteQuery("guilds")
|
||||
require.Equal(t, "delete from guilds where id != ?", query)
|
||||
require.Equal(t, []any{directMessageGuildID}, args)
|
||||
query, args = snapshotDeleteQuery("message_events")
|
||||
require.Equal(t, "delete from message_events where guild_id != ?", query)
|
||||
require.Equal(t, []any{directMessageGuildID}, args)
|
||||
query, args = snapshotDeleteQuery("sync_state")
|
||||
require.Equal(t, "delete from sync_state where scope not like 'wiretap:%'", query)
|
||||
require.Nil(t, args)
|
||||
query, args = snapshotDeleteQuery("custom")
|
||||
require.Equal(t, "delete from custom", query)
|
||||
require.Nil(t, args)
|
||||
@ -684,6 +820,20 @@ func TestShareSmallHelpersAndValidation(t *testing.T) {
|
||||
require.True(t, isDirectMessageSnapshotRow("sync_state", map[string]any{"scope": "wiretap:last_import"}))
|
||||
require.False(t, isDirectMessageSnapshotRow("sync_state", map[string]any{"scope": "share:last_import"}))
|
||||
require.False(t, isDirectMessageSnapshotRow("custom", map[string]any{"guild_id": directMessageGuildID}))
|
||||
require.True(t, isLocalOnlyGuildID(directMessageGuildID))
|
||||
require.False(t, isLocalOnlyGuildID("g1"))
|
||||
|
||||
require.Equal(t, []string{"message_id", "guild_id"}, importColumns(TableManifest{Name: "message_events", Columns: []string{"event_id", "message_id", "guild_id"}}))
|
||||
require.Equal(t, []string{"event_id", "message_id"}, importColumns(TableManifest{Name: "messages", Columns: []string{"event_id", "message_id"}}))
|
||||
require.Equal(t, 7, manifestRowCount(Manifest{
|
||||
Tables: []TableManifest{{Rows: 2}, {Rows: 3}},
|
||||
Embeddings: []EmbeddingManifest{{Rows: 2}},
|
||||
}))
|
||||
var seen []ImportProgress
|
||||
Options{Progress: func(progress ImportProgress) { seen = append(seen, progress) }}.reportProgress(ImportProgress{Phase: "phase"})
|
||||
require.Equal(t, []ImportProgress{{Phase: "phase"}}, seen)
|
||||
Options{}.reportProgress(ImportProgress{Phase: "ignored"})
|
||||
require.Equal(t, mirror.Options{RepoPath: "repo", Remote: "origin", Branch: "main"}, mirrorOptions(Options{RepoPath: "repo", Remote: "origin", Branch: "main"}))
|
||||
|
||||
var buf bytes.Buffer
|
||||
cw := &countingWriter{w: &buf}
|
||||
@ -853,6 +1003,13 @@ func writeShareManifest(t *testing.T, repo string, manifest Manifest) {
|
||||
require.NoError(t, os.WriteFile(filepath.Join(repo, ManifestName), append(body, '\n'), 0o600))
|
||||
}
|
||||
|
||||
func stripFileManifests(manifest Manifest) Manifest {
|
||||
for i := range manifest.Tables {
|
||||
manifest.Tables[i].FileManifests = nil
|
||||
}
|
||||
return manifest
|
||||
}
|
||||
|
||||
func snapshotTableText(t *testing.T, repo string, table TableManifest) string {
|
||||
t.Helper()
|
||||
return snapshotFilesText(t, repo, table.Files)
|
||||
|
||||
@ -1,15 +1,14 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/embed"
|
||||
"github.com/openclaw/crawlkit/embed"
|
||||
"github.com/openclaw/crawlkit/vector"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -476,28 +475,23 @@ func capRunes(value string, maxChars int) string {
|
||||
return string(runes[:maxChars])
|
||||
}
|
||||
|
||||
func EncodeEmbeddingVector(vector []float32) ([]byte, error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(vector)*4))
|
||||
for _, value := range vector {
|
||||
if err := binary.Write(buf, binary.LittleEndian, value); err != nil {
|
||||
return nil, fmt.Errorf("encode embedding vector: %w", err)
|
||||
}
|
||||
func EncodeEmbeddingVector(values []float32) ([]byte, error) {
|
||||
blob, err := vector.EncodeFloat32(values)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode embedding vector: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
func DecodeEmbeddingVector(blob []byte) ([]float32, error) {
|
||||
if len(blob)%4 != 0 {
|
||||
return nil, fmt.Errorf("embedding blob length %d is not a float32 multiple", len(blob))
|
||||
}
|
||||
out := make([]float32, len(blob)/4)
|
||||
reader := bytes.NewReader(blob)
|
||||
for i := range out {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &out[i]); err != nil {
|
||||
return nil, fmt.Errorf("decode embedding vector: %w", err)
|
||||
}
|
||||
values, err := vector.DecodeFloat32(blob)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode embedding vector: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func (s *Store) EmbeddingBacklog(ctx context.Context) (int, error) {
|
||||
|
||||
@ -5,11 +5,12 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/openclaw/crawlkit/vector"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -160,7 +161,7 @@ func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchO
|
||||
if len(opts.QueryVector) != opts.Dimensions {
|
||||
return nil, fmt.Errorf("semantic query embedding dimensions mismatch: got %d want %d", len(opts.QueryVector), opts.Dimensions)
|
||||
}
|
||||
queryNorm := vectorNorm(opts.QueryVector)
|
||||
queryNorm := vector.Norm(opts.QueryVector)
|
||||
if queryNorm == 0 {
|
||||
return nil, errors.New("semantic query embedding returned a zero vector")
|
||||
}
|
||||
@ -236,15 +237,18 @@ func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchO
|
||||
if dimensions != opts.Dimensions {
|
||||
return nil, fmt.Errorf("stored embedding dimensions mismatch for message %s: got %d want %d", row.MessageID, dimensions, opts.Dimensions)
|
||||
}
|
||||
vector, err := DecodeEmbeddingVector(blob)
|
||||
storedVector, err := DecodeEmbeddingVector(blob)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode embedding for message %s: %w", row.MessageID, err)
|
||||
}
|
||||
if len(vector) != dimensions {
|
||||
return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(vector), dimensions)
|
||||
if len(storedVector) != dimensions {
|
||||
return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(storedVector), dimensions)
|
||||
}
|
||||
score, err := cosineSimilarity(opts.QueryVector, queryNorm, vector)
|
||||
score, err := vector.CosineSimilarity(opts.QueryVector, queryNorm, storedVector)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "candidate vector is zero") {
|
||||
return nil, fmt.Errorf("score embedding for message %s: stored embedding vector is zero", row.MessageID)
|
||||
}
|
||||
return nil, fmt.Errorf("score embedding for message %s: %w", row.MessageID, err)
|
||||
}
|
||||
row.CreatedAt = parseTime(created)
|
||||
@ -328,26 +332,23 @@ func fuseSearchResults(ftsResults, semanticResults []SearchResult, limit int) []
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
entries := make(map[string]*hybridSearchEntry, len(ftsResults)+len(semanticResults))
|
||||
addResults := func(results []SearchResult, weight float64, fts bool) {
|
||||
for index, result := range results {
|
||||
entry := entries[result.MessageID]
|
||||
if entry == nil {
|
||||
entry = &hybridSearchEntry{result: result}
|
||||
entries[result.MessageID] = entry
|
||||
}
|
||||
if fts {
|
||||
entry.hasFTS = true
|
||||
}
|
||||
entry.score += weight / (rrfK + float64(index+1))
|
||||
}
|
||||
id := func(result SearchResult) string {
|
||||
return result.MessageID
|
||||
}
|
||||
addResults(ftsResults, ftsRRFWeight, true)
|
||||
addResults(semanticResults, semanticRRFWeight, false)
|
||||
|
||||
merged := make([]hybridSearchEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
merged = append(merged, *entry)
|
||||
ftsIDs := make(map[string]struct{}, len(ftsResults))
|
||||
for _, result := range ftsResults {
|
||||
ftsIDs[result.MessageID] = struct{}{}
|
||||
}
|
||||
fused := vector.ReciprocalRankFusion(
|
||||
[][]SearchResult{ftsResults, semanticResults},
|
||||
[]func(SearchResult) string{id, id},
|
||||
[]float64{ftsRRFWeight, semanticRRFWeight},
|
||||
rrfK,
|
||||
)
|
||||
merged := make([]hybridSearchEntry, 0, len(fused))
|
||||
for _, entry := range fused {
|
||||
_, hasFTS := ftsIDs[entry.Item.MessageID]
|
||||
merged = append(merged, hybridSearchEntry{result: entry.Item, score: entry.Score, hasFTS: hasFTS})
|
||||
}
|
||||
sort.SliceStable(merged, func(i, j int) bool {
|
||||
if merged[i].score != merged[j].score {
|
||||
@ -490,29 +491,6 @@ func (s *Store) searchFallback(ctx context.Context, opts SearchOptions) ([]Searc
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func cosineSimilarity(query []float32, queryNorm float64, vector []float32) (float64, error) {
|
||||
if len(vector) != len(query) {
|
||||
return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(vector), len(query))
|
||||
}
|
||||
vectorNorm := vectorNorm(vector)
|
||||
if vectorNorm == 0 {
|
||||
return 0, errors.New("stored embedding vector is zero")
|
||||
}
|
||||
var dot float64
|
||||
for i := range query {
|
||||
dot += float64(query[i]) * float64(vector[i])
|
||||
}
|
||||
return dot / (queryNorm * vectorNorm), nil
|
||||
}
|
||||
|
||||
func vectorNorm(vector []float32) float64 {
|
||||
var sum float64
|
||||
for _, value := range vector {
|
||||
sum += float64(value) * float64(value)
|
||||
}
|
||||
return math.Sqrt(sum)
|
||||
}
|
||||
|
||||
func (s *Store) Members(ctx context.Context, guildID, query string, limit int) ([]MemberRow, error) {
|
||||
if strings.TrimSpace(query) != "" {
|
||||
return s.searchMembers(ctx, guildID, query, limit)
|
||||
|
||||
@ -3,12 +3,13 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
crawlstore "github.com/vincentkoc/crawlkit/store"
|
||||
crawlstore "github.com/openclaw/crawlkit/store"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -18,6 +19,8 @@ const (
|
||||
storeSchemaVersion = 2
|
||||
)
|
||||
|
||||
var ErrSchemaVersionMismatch = errors.New("database schema version mismatch")
|
||||
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
path string
|
||||
@ -135,7 +138,7 @@ func OpenReadOnly(ctx context.Context, path string) (*Store, error) {
|
||||
return nil, err
|
||||
} else if version != storeSchemaVersion {
|
||||
_ = base.Close()
|
||||
return nil, fmt.Errorf("database schema version mismatch: got %d want %d", version, storeSchemaVersion)
|
||||
return nil, fmt.Errorf("%w: got %d want %d", ErrSchemaVersionMismatch, version, storeSchemaVersion)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
@ -179,7 +182,7 @@ func (s *Store) migrate(ctx context.Context) error {
|
||||
if version, err := s.schemaVersion(ctx); err != nil {
|
||||
return err
|
||||
} else if version != storeSchemaVersion {
|
||||
return fmt.Errorf("database schema version mismatch: got %d want %d", version, storeSchemaVersion)
|
||||
return fmt.Errorf("%w: got %d want %d", ErrSchemaVersionMismatch, version, storeSchemaVersion)
|
||||
}
|
||||
if err := s.applyQueryIndexMigration(ctx); err != nil {
|
||||
return err
|
||||
|
||||
@ -389,6 +389,99 @@ func TestStoreReadWriteAndSearch(t *testing.T) {
|
||||
require.Equal(t, "Peter", messageRows[0].AuthorName)
|
||||
}
|
||||
|
||||
func TestListMessagesWithThreadContextAndMentionLabels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
base := time.Date(2026, 5, 8, 12, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c2", GuildID: "g2", Kind: "text", Name: "other", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertMember(ctx, MemberRecord{
|
||||
GuildID: "g1",
|
||||
UserID: "u1",
|
||||
Username: "alice",
|
||||
DisplayName: "Alice",
|
||||
RoleIDsJSON: `[]`,
|
||||
RawJSON: `{}`,
|
||||
}))
|
||||
require.NoError(t, s.UpsertMember(ctx, MemberRecord{
|
||||
GuildID: "g2",
|
||||
UserID: "u1",
|
||||
Username: "other-alice",
|
||||
DisplayName: "Other Alice",
|
||||
RoleIDsJSON: `[]`,
|
||||
RawJSON: `{}`,
|
||||
}))
|
||||
require.NoError(t, s.UpsertMessages(ctx, []MessageMutation{
|
||||
{
|
||||
Record: MessageRecord{
|
||||
ID: "root",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "Alice",
|
||||
CreatedAt: base.Format(time.RFC3339Nano),
|
||||
Content: "root mentions <@u1> and <#c1>",
|
||||
NormalizedContent: "root mentions <@u1> and <#c1>",
|
||||
RawJSON: `{}`,
|
||||
},
|
||||
Mentions: []MentionEventRecord{{
|
||||
MessageID: "root",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
AuthorID: "u1",
|
||||
TargetType: "role",
|
||||
TargetID: "r1",
|
||||
TargetName: "Maintainers",
|
||||
EventAt: base.Format(time.RFC3339Nano),
|
||||
}},
|
||||
},
|
||||
{
|
||||
Record: MessageRecord{
|
||||
ID: "reply",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "Alice",
|
||||
CreatedAt: base.Add(time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "reply to root <@&r1>",
|
||||
NormalizedContent: "reply to root <@&r1>",
|
||||
ReplyToMessageID: "root",
|
||||
RawJSON: `{}`,
|
||||
},
|
||||
Mentions: []MentionEventRecord{{
|
||||
MessageID: "reply",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
AuthorID: "u1",
|
||||
TargetType: "role",
|
||||
TargetID: "r1",
|
||||
TargetName: "Maintainers",
|
||||
EventAt: base.Add(time.Minute).Format(time.RFC3339Nano),
|
||||
}},
|
||||
},
|
||||
}))
|
||||
|
||||
rows, err := s.ListMessagesWithThreadContext(ctx, MessageListOptions{Channel: "general", Since: base.Add(30 * time.Second), Limit: 1})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"reply", "root"}, messageRowIDs(rows))
|
||||
require.Equal(t, "reply to root @Maintainers", rows[0].DisplayContent)
|
||||
require.Equal(t, "root mentions @Alice and #general", rows[1].DisplayContent)
|
||||
|
||||
merged := mergeMessageRows(rows[:1], []MessageRow{rows[0], {MessageID: "other", GuildID: "g1", ChannelID: "c1"}})
|
||||
require.Equal(t, []string{"reply", "other"}, messageRowIDs(merged))
|
||||
require.Equal(t, "@fallback", replaceDiscordMention("<@missing>", "user", "missing", "fallback"))
|
||||
require.Equal(t, "#chan", replaceDiscordMention("<#c1>", "channel", "c1", "chan"))
|
||||
require.Equal(t, "<@u2>", replaceDiscordMention("<@u2>", "user", "", "blank"))
|
||||
}
|
||||
|
||||
func TestSearchMessagesPrefersRecentMessageIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -843,6 +936,14 @@ func searchResultIDs(results []SearchResult) []string {
|
||||
return ids
|
||||
}
|
||||
|
||||
func messageRowIDs(rows []MessageRow) []string {
|
||||
ids := make([]string, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
ids = append(ids, row.MessageID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func TestCheckMessageFTSProbe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -937,6 +1038,33 @@ func TestOpenSetsSchemaVersion(t *testing.T) {
|
||||
require.Equal(t, storeSchemaVersion, version)
|
||||
}
|
||||
|
||||
func TestOpenReadOnlySchemaChecks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
dbPath := filepath.Join(t.TempDir(), "discrawl.db")
|
||||
s, err := Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
ro, err := OpenReadOnly(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
status, err := ro.Status(ctx, dbPath, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, status.GuildCount)
|
||||
require.NoError(t, ro.Close())
|
||||
|
||||
future, err := sql.Open("sqlite", dbPath)
|
||||
require.NoError(t, err)
|
||||
_, err = future.ExecContext(ctx, `pragma user_version = 999`)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, future.Close())
|
||||
|
||||
_, err = OpenReadOnly(ctx, dbPath)
|
||||
require.ErrorContains(t, err, "database schema version mismatch")
|
||||
}
|
||||
|
||||
func TestOpenFailsOnFutureSchemaVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@ -10,9 +10,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openclaw/crawlkit/embed"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/embed"
|
||||
)
|
||||
|
||||
func TestUpsertMessagesBatch(t *testing.T) {
|
||||
|
||||
@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/vincentkoc/crawlkit/progress"
|
||||
"github.com/openclaw/crawlkit/progress"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
)
|
||||
@ -187,7 +187,7 @@ func (s *Syncer) syncMessageChannelsConcurrent(
|
||||
}
|
||||
|
||||
func (s *Syncer) clearUnavailableChannel(ctx context.Context, channelID string) error {
|
||||
if s.store == nil || channelID == "" {
|
||||
if s == nil || s.store == nil || channelID == "" {
|
||||
return nil
|
||||
}
|
||||
return s.store.DeleteSyncState(ctx, "channel:"+channelID+":unavailable")
|
||||
@ -616,6 +616,9 @@ func (p *messageSyncProgress) record(channel *discordgo.Channel, count int) {
|
||||
}
|
||||
|
||||
func (p *messageSyncProgress) recordSkip(channel *discordgo.Channel, err error) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
outcome := syncErrorOutcome(err)
|
||||
p.mu.Lock()
|
||||
switch outcome {
|
||||
|
||||
214
internal/syncer/message_sync_helpers_test.go
Normal file
214
internal/syncer/message_sync_helpers_test.go
Normal file
@ -0,0 +1,214 @@
|
||||
package syncer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/openclaw/discrawl/internal/store"
|
||||
)
|
||||
|
||||
func TestMessageChannelSelectionAndTimeoutHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parent := &discordgo.Channel{ID: "forum", GuildID: "g1", Name: "forum", Type: discordgo.ChannelTypeGuildForum}
|
||||
thread := &discordgo.Channel{ID: "thread", GuildID: "g1", ParentID: "forum", Name: "thread", Type: discordgo.ChannelTypeGuildPublicThread}
|
||||
text := &discordgo.Channel{ID: "text", GuildID: "g1", Name: "text", Type: discordgo.ChannelTypeGuildText}
|
||||
voice := &discordgo.Channel{ID: "voice", GuildID: "g1", Name: "voice", Type: discordgo.ChannelTypeGuildVoice}
|
||||
|
||||
rows := filterMessageChannels([]*discordgo.Channel{nil, parent, thread, text, voice}, []string{"forum"})
|
||||
require.Equal(t, []string{"thread"}, channelIDs(rows))
|
||||
require.False(t, requestedMessageTarget(nil, nil, map[string]struct{}{}))
|
||||
require.True(t, requestedMessageTarget(text, map[string]*discordgo.Channel{"text": text}, map[string]struct{}{"text": {}}))
|
||||
require.False(t, requestedMessageTarget(thread, map[string]*discordgo.Channel{}, map[string]struct{}{"forum": {}}))
|
||||
|
||||
ctx, cancel := (*Syncer)(nil).messageChannelContext(context.Background())
|
||||
require.NoError(t, ctx.Err())
|
||||
cancel()
|
||||
require.ErrorIs(t, ctx.Err(), context.Canceled)
|
||||
|
||||
svc := New(&fakeClient{}, nil, nil)
|
||||
svc.messageChannelTimeout = time.Second
|
||||
ctx, cancel = svc.messageChannelContext(context.Background())
|
||||
defer cancel()
|
||||
_, ok := ctx.Deadline()
|
||||
require.True(t, ok)
|
||||
|
||||
parentCtx, parentCancel := context.WithDeadline(context.Background(), time.Now().Add(time.Hour))
|
||||
defer parentCancel()
|
||||
ctx, cancel = svc.messageChannelContext(parentCtx)
|
||||
defer cancel()
|
||||
deadline, ok := ctx.Deadline()
|
||||
require.True(t, ok)
|
||||
parentDeadline, _ := parentCtx.Deadline()
|
||||
require.Equal(t, parentDeadline, deadline)
|
||||
}
|
||||
|
||||
func TestChannelSyncStateHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
channel := &discordgo.Channel{ID: "c1", LastMessageID: "200"}
|
||||
require.False(t, shouldSkipChannelSync(nil, channelSyncState{BackfillComplete: true}))
|
||||
require.True(t, shouldSkipChannelSync(&discordgo.Channel{ID: "c1"}, channelSyncState{BackfillComplete: true, Latest: ""}))
|
||||
require.False(t, shouldSkipChannelSync(channel, channelSyncState{BackfillComplete: true, Latest: ""}))
|
||||
require.True(t, shouldSkipChannelSync(channel, channelSyncState{BackfillComplete: true, Latest: "300"}))
|
||||
require.False(t, shouldSkipLatestOnlyChannelSync(nil, channelSyncState{Latest: "300"}))
|
||||
require.False(t, shouldSkipLatestOnlyChannelSync(channel, channelSyncState{}))
|
||||
require.True(t, shouldSkipLatestOnlyChannelSync(channel, channelSyncState{Latest: "300"}))
|
||||
|
||||
messages := []*discordgo.Message{
|
||||
{ID: "3", Timestamp: time.Date(2026, 5, 8, 12, 0, 0, 0, time.UTC)},
|
||||
{ID: "2", Timestamp: time.Date(2026, 5, 8, 11, 0, 0, 0, time.UTC)},
|
||||
{ID: "1", Timestamp: time.Date(2026, 5, 8, 10, 0, 0, 0, time.UTC)},
|
||||
}
|
||||
filtered, reached := filterMessagesSince(messages, time.Date(2026, 5, 8, 10, 30, 0, 0, time.UTC))
|
||||
require.True(t, reached)
|
||||
require.Equal(t, []string{"3", "2"}, messageIDs(filtered))
|
||||
filtered, reached = filterMessagesSince(messages, time.Time{})
|
||||
require.False(t, reached)
|
||||
require.Len(t, filtered, 3)
|
||||
}
|
||||
|
||||
func TestChannelSyncStateStoreHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
|
||||
ID: "100",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "User",
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}))
|
||||
|
||||
svc := New(&fakeClient{}, s, nil)
|
||||
state := channelSyncState{}
|
||||
require.NoError(t, svc.seedChannelSyncState(ctx, "c1", &state))
|
||||
require.Equal(t, "100", state.Latest)
|
||||
require.Equal(t, "100", state.BackfillCursor)
|
||||
|
||||
state = channelSyncState{StoredLatest: "100"}
|
||||
require.NoError(t, svc.seedChannelSyncState(ctx, "missing-channel", &state))
|
||||
require.True(t, state.BackfillComplete)
|
||||
|
||||
require.NoError(t, s.SetSyncState(ctx, channelLatestScope("c1"), "200"))
|
||||
require.NoError(t, s.SetSyncState(ctx, channelBackfillScope("c1"), "100"))
|
||||
require.NoError(t, s.SetSyncState(ctx, channelHistoryCompleteScope("c1"), "1"))
|
||||
loaded, err := svc.loadChannelSyncState(ctx, "c1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, channelSyncState{Latest: "200", StoredLatest: "200", BackfillCursor: "100", BackfillComplete: true}, loaded)
|
||||
}
|
||||
|
||||
func TestMessageChannelSyncBranches(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
svc := New(&fakeClient{}, s, nil)
|
||||
count, err := svc.syncMessageChannels(ctx, "g1", nil, SyncOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, count)
|
||||
require.NoError(t, svc.clearUnavailableChannel(ctx, ""))
|
||||
require.NoError(t, (*Syncer)(nil).clearUnavailableChannel(ctx, "c1"))
|
||||
|
||||
channel := &discordgo.Channel{ID: "c1", GuildID: "g1", Name: "general", Type: discordgo.ChannelTypeGuildText}
|
||||
client := &fakeClient{
|
||||
messages: map[string][]*discordgo.Message{
|
||||
"c1": {{
|
||||
ID: "100",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
Content: "hello",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Author: &discordgo.User{ID: "u1", Username: "user"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
svc = New(client, s, nil)
|
||||
count, err = svc.syncMessageChannelsSerial(ctx, "g1", []*discordgo.Channel{channel}, SyncOptions{Full: true}, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
errChannel := &discordgo.Channel{ID: "c-err", GuildID: "g1", Name: "errors", Type: discordgo.ChannelTypeGuildText}
|
||||
client.messageErrors = map[string]error{"c-err": errors.New(`HTTP 500 Internal Server Error`)}
|
||||
count, err = svc.syncMessageChannelsSerial(ctx, "g1", []*discordgo.Channel{errChannel}, SyncOptions{Full: true}, nil)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, count)
|
||||
|
||||
client.messageErrors = map[string]error{"c-err": errors.New("hard failure")}
|
||||
count, err = svc.syncMessageChannelsSerial(ctx, "g1", []*discordgo.Channel{errChannel}, SyncOptions{Full: true}, nil)
|
||||
require.ErrorContains(t, err, "sync channel c-err")
|
||||
require.Zero(t, count)
|
||||
}
|
||||
|
||||
func TestMessageChannelConcurrentErrorAndProgressBranches(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
channels := []*discordgo.Channel{
|
||||
{ID: "c1", GuildID: "g1", Name: "one", Type: discordgo.ChannelTypeGuildText},
|
||||
{ID: "c2", GuildID: "g1", Name: "two", Type: discordgo.ChannelTypeGuildText},
|
||||
}
|
||||
client := &fakeClient{
|
||||
messages: map[string][]*discordgo.Message{
|
||||
"c1": {{
|
||||
ID: "101",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
Content: "one",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Author: &discordgo.User{ID: "u1", Username: "user"},
|
||||
}},
|
||||
},
|
||||
messageErrors: map[string]error{"c2": errors.New("hard failure")},
|
||||
}
|
||||
svc := New(client, s, slog.New(slog.DiscardHandler))
|
||||
count, err := svc.syncMessageChannelsConcurrent(ctx, "g1", channels, SyncOptions{Full: true}, 2, newMessageSyncProgress(svc, "g1", len(channels), SyncOptions{Full: true, Concurrency: 2}))
|
||||
require.ErrorContains(t, err, "sync channel c2")
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
progress := &messageSyncProgress{}
|
||||
progress.start(nil)
|
||||
progress.touch(nil, 1)
|
||||
progress.finish(nil)
|
||||
progress.logWaitHeartbeat()
|
||||
require.Equal(t, "skipped", syncErrorOutcome(errors.New("plain")))
|
||||
}
|
||||
|
||||
func channelIDs(channels []*discordgo.Channel) []string {
|
||||
out := make([]string, 0, len(channels))
|
||||
for _, channel := range channels {
|
||||
out = append(out, channel.ID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func messageIDs(messages []*discordgo.Message) []string {
|
||||
out := make([]string, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
out = append(out, message.ID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user