feat(search): add hybrid message search

This commit is contained in:
MrBrain 2026-04-22 14:49:52 +08:00 committed by Peter Steinberger
parent ffc622cd5e
commit 3ea1d4aa7f
5 changed files with 432 additions and 11 deletions

View File

@ -11,6 +11,7 @@ All notable changes to `discrawl` will be documented in this file.
- `embed` now drains the queued embedding backlog in bounded batches, requeues safely on provider throttling, and drops stale stored vectors when messages no longer have embeddable content
- Git-backed snapshots now keep embedding queue state and generated vectors local to each archive, so subscribers no longer inherit misleading embedding backlog metadata. (#38) Thanks @GaosCode.
- semantic message search now ranks across the full compatible local vector set instead of only the newest candidate window. (#36) Thanks @GaosCode.
- hybrid message search now fuses FTS with local semantic vectors while avoiding embedding-provider calls when no local vectors exist. (#37) Thanks @GaosCode.
## 0.3.0 - 2026-04-21

View File

@ -634,6 +634,91 @@ func TestSearchSemanticCommandUsesStoredEmbeddings(t *testing.T) {
require.Equal(t, 2, requests)
}
// TestSearchHybridCommandFusesResults runs the `search` command with hybrid
// configured as the default mode and checks that all three seeded messages
// show up in the output for the query "panic". One of them ("database worker
// stalled") does not contain the query word, so it is presumably only
// reachable through the semantic side of the fusion.
func TestSearchHybridCommandFusesResults(t *testing.T) {
	ctx := context.Background()
	workDir := t.TempDir()
	cfgPath := filepath.Join(workDir, "config.toml")
	dbPath := filepath.Join(workDir, "discrawl.db")

	// Stub embeddings endpoint: verifies the request shape and always answers
	// with the query vector [1, 0].
	embedHandler := func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/embeddings", r.URL.Path)
		var payload struct {
			Model string   `json:"model"`
			Input []string `json:"input"`
		}
		assert.NoError(t, json.NewDecoder(r.Body).Decode(&payload))
		assert.Equal(t, "local-model", payload.Model)
		assert.Equal(t, []string{"panic"}, payload.Input)
		_, _ = w.Write([]byte(`{"model":"local-model","data":[{"index":0,"embedding":[1,0]}]}`))
	}
	server := httptest.NewServer(http.HandlerFunc(embedHandler))
	defer server.Close()

	cfg := config.Default()
	cfg.DBPath = dbPath
	cfg.Search.DefaultMode = "hybrid"
	cfg.Search.Embeddings.Enabled = true
	cfg.Search.Embeddings.Provider = "openai_compatible"
	cfg.Search.Embeddings.Model = "local-model"
	cfg.Search.Embeddings.BaseURL = server.URL
	cfg.Search.Embeddings.APIKeyEnv = ""
	require.NoError(t, config.Write(cfgPath, cfg))

	s, err := store.Open(ctx, dbPath)
	require.NoError(t, err)
	require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
	require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))

	// Messages are inserted oldest first; each is one minute newer than the last.
	base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
	seeds := []struct {
		id, authorID, authorName, content, rawJSON string
		createdAt                                  time.Time
	}{
		{"m3", "u1", "Alice", "panic stack trace", `{"author":{"username":"Alice"}}`, base},
		{"m2", "u2", "Bob", "database worker stalled", `{"author":{"username":"Bob"}}`, base.Add(time.Minute)},
		{"m1", "u3", "Carol", "panic database lock", `{"author":{"username":"Carol"}}`, base.Add(2 * time.Minute)},
	}
	for _, seed := range seeds {
		require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
			ID:                seed.id,
			GuildID:           "g1",
			ChannelID:         "c1",
			ChannelName:       "general",
			AuthorID:          seed.authorID,
			AuthorName:        seed.authorName,
			MessageType:       0,
			CreatedAt:         seed.createdAt.Format(time.RFC3339Nano),
			Content:           seed.content,
			NormalizedContent: seed.content,
			RawJSON:           seed.rawJSON,
		}))
	}

	// Stored vectors for the configured provider/model, one per message.
	vectors := []struct {
		messageID string
		values    []float32
	}{
		{"m1", []float32{0.9, 0.1}},
		{"m2", []float32{1, 0}},
		{"m3", []float32{0, 1}},
	}
	for _, vec := range vectors {
		require.NoError(t, insertCLIEmbedding(ctx, s, vec.messageID, "openai_compatible", "local-model", vec.values))
	}
	require.NoError(t, s.Close())

	var stdout, stderr bytes.Buffer
	args := []string{"--config", cfgPath, "search", "--limit", "3", "panic"}
	require.NoError(t, Run(ctx, args, &stdout, &stderr))

	printed := stdout.String()
	for _, want := range []string{"panic database lock", "database worker stalled", "panic stack trace"} {
		require.Contains(t, printed, want)
	}
}
func TestSearchSemanticCommandErrors(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
@ -651,10 +736,6 @@ func TestSearchSemanticCommandErrors(t *testing.T) {
require.Equal(t, 2, ExitCode(err))
require.ErrorContains(t, err, `unsupported search mode "bogus"`)
err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
require.Equal(t, 1, ExitCode(err))
require.ErrorContains(t, err, "hybrid search is not implemented yet")
err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
require.Equal(t, 1, ExitCode(err))
require.ErrorContains(t, err, "embeddings are disabled")
@ -675,6 +756,81 @@ func TestSearchSemanticCommandErrors(t *testing.T) {
require.ErrorContains(t, err, "embedding query failed")
}
// TestSearchHybridCommandFallsBackToFTS runs `search --mode hybrid` three
// times against the same single-message archive, rewriting the config between
// runs, and expects the FTS match to be printed each time:
//  1. with the default config (embeddings not configured by this test),
//  2. with embeddings enabled but no stored vectors, where the embedding
//     server must not be contacted at all,
//  3. with a stored vector and an embedding server that answers HTTP 500,
//     where exactly one request is made before falling back.
func TestSearchHybridCommandFallsBackToFTS(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.toml")
dbPath := filepath.Join(dir, "discrawl.db")
cfg := config.Default()
cfg.DBPath = dbPath
require.NoError(t, config.Write(cfgPath, cfg))
s, err := store.Open(ctx, dbPath)
require.NoError(t, err)
require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
ChannelName: "general",
AuthorID: "u1",
AuthorName: "Alice",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "panic exact match",
NormalizedContent: "panic exact match",
RawJSON: `{"author":{"username":"Alice"}}`,
}))
require.NoError(t, s.Close())
// Phase 1: default config; hybrid mode must still succeed and print the FTS hit.
var out bytes.Buffer
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "panic"}, &out, &bytes.Buffer{}))
require.Contains(t, out.String(), "panic exact match")
// Phase 2: embeddings enabled and pointing at a working server, but the
// archive holds no vectors yet. The request counter must stay at zero.
okRequests := 0
okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
okRequests++
_, _ = w.Write([]byte(`{"model":"local-model","data":[{"index":0,"embedding":[1,0]}]}`))
}))
defer okServer.Close()
cfg.Search.Embeddings.Enabled = true
cfg.Search.Embeddings.Provider = "openai_compatible"
cfg.Search.Embeddings.Model = "local-model"
cfg.Search.Embeddings.BaseURL = okServer.URL
cfg.Search.Embeddings.APIKeyEnv = ""
require.NoError(t, config.Write(cfgPath, cfg))
out.Reset()
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "panic"}, &out, &bytes.Buffer{}))
require.Contains(t, out.String(), "panic exact match")
require.Equal(t, 0, okRequests)
// Phase 3: store a vector for m1 so the semantic path is attempted, then
// point the config at a server that always fails.
s, err = store.Open(ctx, dbPath)
require.NoError(t, err)
require.NoError(t, insertCLIEmbedding(ctx, s, "m1", "openai_compatible", "local-model", []float32{1, 0}))
require.NoError(t, s.Close())
failedRequests := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
failedRequests++
http.Error(w, "nope", http.StatusInternalServerError)
}))
defer server.Close()
cfg.Search.Embeddings.Enabled = true
cfg.Search.Embeddings.Provider = "openai_compatible"
cfg.Search.Embeddings.Model = "local-model"
cfg.Search.Embeddings.BaseURL = server.URL
cfg.Search.Embeddings.APIKeyEnv = ""
require.NoError(t, config.Write(cfgPath, cfg))
out.Reset()
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "panic"}, &out, &bytes.Buffer{}))
require.Contains(t, out.String(), "panic exact match")
// The failing server is hit exactly once; the command still succeeds via FTS.
require.Equal(t, 1, failedRequests)
}
func insertCLIEmbedding(ctx context.Context, s *store.Store, messageID, provider, model string, vector []float32) error {
blob, err := store.EncodeEmbeddingVector(vector)
if err != nil {

View File

@ -2,6 +2,7 @@ package cli
import (
"bufio"
"errors"
"flag"
"fmt"
"io"
@ -51,15 +52,57 @@ func (r *runtime) runSearch(args []string) error {
}
return r.print(results)
case "hybrid":
return fmt.Errorf("hybrid search is not implemented yet")
results, err := r.searchMessagesHybrid(opts)
if err != nil {
return err
}
return r.print(results)
default:
return usageErr(fmt.Errorf("unsupported search mode %q", *mode))
}
}
// searchMessagesSemantic derives semantic search options from opts via
// semanticSearchOptions and, when that succeeds, hands them to the store's
// semantic search. An error while building the options is returned unchanged.
func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.SearchResult, error) {
	query, buildErr := r.semanticSearchOptions(opts)
	if buildErr == nil {
		return r.store.SearchMessagesSemantic(r.ctx, query)
	}
	return nil, buildErr
}
// searchMessagesHybrid runs a hybrid (FTS + semantic) message search and
// degrades to plain FTS whenever the semantic side cannot contribute.
//
// It falls back to r.store.SearchMessages when:
//   - embeddings are disabled in the configuration,
//   - the store holds no embeddings for the configured provider, model, and
//     store.EmbeddingInputVersion (checked before semanticSearchOptions is
//     called, so the query is not embedded in that case),
//   - semanticSearchOptions fails for any reason,
//   - the hybrid store search fails with an error that
//     hybridSemanticUnavailable classifies as "semantic side unavailable".
//
// Errors from the embeddings-presence check and any other hybrid search error
// are returned to the caller.
func (r *runtime) searchMessagesHybrid(opts store.SearchOptions) ([]store.SearchResult, error) {
	if !r.cfg.Search.Embeddings.Enabled {
		// Hybrid mode must keep working without embeddings: use FTS only
		// instead of failing the command.
		return r.store.SearchMessages(r.ctx, opts)
	}

	hasEmbeddings, err := r.store.HasMessageEmbeddings(
		r.ctx,
		r.cfg.Search.Embeddings.Provider,
		r.cfg.Search.Embeddings.Model,
		store.EmbeddingInputVersion,
	)
	if err != nil {
		return nil, err
	}
	if !hasEmbeddings {
		return r.store.SearchMessages(r.ctx, opts)
	}

	semanticOpts, err := r.semanticSearchOptions(opts)
	if err != nil {
		// Deliberate best-effort: a failed query embedding degrades to FTS.
		return r.store.SearchMessages(r.ctx, opts)
	}

	results, err := r.store.SearchMessagesHybrid(r.ctx, opts, semanticOpts)
	if err != nil {
		if hybridSemanticUnavailable(err) {
			return r.store.SearchMessages(r.ctx, opts)
		}
		return nil, err
	}
	return results, nil
}
func (r *runtime) semanticSearchOptions(opts store.SearchOptions) (store.SemanticSearchOptions, error) {
if !r.cfg.Search.Embeddings.Enabled {
return store.SemanticSearchOptions{}, fmt.Errorf("embeddings are disabled; enable [search.embeddings] first")
}
providerFactory := r.newEmbed
if providerFactory == nil {
@ -69,21 +112,21 @@ func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.Sear
}
provider, err := providerFactory(r.cfg.Search.Embeddings)
if err != nil {
return nil, fmt.Errorf("create embedding provider: %w", err)
return store.SemanticSearchOptions{}, fmt.Errorf("create embedding provider: %w", err)
}
batch, err := provider.Embed(r.ctx, []string{opts.Query})
if err != nil {
return nil, fmt.Errorf("embedding query failed: %w", err)
return store.SemanticSearchOptions{}, fmt.Errorf("embedding query failed: %w", err)
}
if len(batch.Vectors) != 1 {
return nil, fmt.Errorf("embedding query returned %d vectors for 1 input", len(batch.Vectors))
return store.SemanticSearchOptions{}, fmt.Errorf("embedding query returned %d vectors for 1 input", len(batch.Vectors))
}
queryVector := batch.Vectors[0]
dimensions := batch.Dimensions
if dimensions == 0 {
dimensions = len(queryVector)
}
return r.store.SearchMessagesSemantic(r.ctx, store.SemanticSearchOptions{
return store.SemanticSearchOptions{
QueryVector: queryVector,
Provider: r.cfg.Search.Embeddings.Provider,
Model: r.cfg.Search.Embeddings.Model,
@ -94,7 +137,11 @@ func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.Sear
Author: opts.Author,
Limit: opts.Limit,
IncludeEmpty: opts.IncludeEmpty,
})
}, nil
}
// hybridSemanticUnavailable reports whether err means the semantic half of a
// hybrid search could not be used, so the caller may fall back to FTS.
//
// It returns true when err wraps store.ErrNoCompatibleEmbeddings or when its
// message starts with "semantic query embedding ". The prefix match depends on
// error text produced outside this function (presumably by the store's query
// vector validation; confirm before rewording those messages). A nil err
// yields false.
func hybridSemanticUnavailable(err error) bool {
	if err == nil {
		// Guard: calling err.Error() on a nil error would panic.
		return false
	}
	return errors.Is(err, store.ErrNoCompatibleEmbeddings) || strings.HasPrefix(err.Error(), "semantic query embedding ")
}
func (r *runtime) runSQL(args []string) error {

View File

@ -20,6 +20,9 @@ const (
searchCandidateCap = 5000
searchCandidateMultiple = 20
messageFTSHealthProbe = "__discrawl_probe__"
rrfK = 60.0
ftsRRFWeight = 1.0
semanticRRFWeight = 1.0
)
var ErrNoCompatibleEmbeddings = errors.New("no compatible message embeddings for provider/model/input version; run discrawl embed --rebuild")
@ -294,6 +297,80 @@ func semanticScoreLess(left, right semanticScoredResult) bool {
return left.result.MessageID > right.result.MessageID
}
// SearchMessagesHybrid runs the FTS search described by opts and the semantic
// search described by semanticOpts, then merges both ranked lists with
// fuseSearchResults.
//
// The number of results returned is bounded by opts.Limit (20 when it is not
// positive). Both underlying searches are run with the limit returned by
// searchCandidateLimit for that value; any Limit already set on semanticOpts
// is overwritten. An error from either search is returned unchanged and no
// partial results are produced.
func (s *Store) SearchMessagesHybrid(ctx context.Context, opts SearchOptions, semanticOpts SemanticSearchOptions) ([]SearchResult, error) {
	want := opts.Limit
	if want <= 0 {
		want = 20
	}
	pool := searchCandidateLimit(want)

	lexicalOpts := opts
	lexicalOpts.Limit = pool
	lexicalHits, err := s.SearchMessages(ctx, lexicalOpts)
	if err != nil {
		return nil, err
	}

	semanticOpts.Limit = pool
	semanticHits, err := s.SearchMessagesSemantic(ctx, semanticOpts)
	if err != nil {
		return nil, err
	}

	return fuseSearchResults(lexicalHits, semanticHits, want), nil
}
// hybridSearchEntry is the per-message accumulator used by fuseSearchResults
// while merging the FTS and semantic result lists.
type hybridSearchEntry struct {
// result is the SearchResult from the first list in which the message was seen.
result SearchResult
// score is the running sum of weight / (rrfK + rank) over every list that
// contains the message.
score float64
// hasFTS is true when the message appeared in the FTS list; used as a tie-break.
hasFTS bool
}
// fuseSearchResults merges two ranked result lists into one, deduplicating by
// MessageID. Each occurrence of a message at 1-based rank p contributes
// weight / (rrfK + p) to its score, with ftsRRFWeight for ftsResults and
// semanticRRFWeight for semanticResults.
//
// The merged list is ordered by score (highest first); ties are broken in
// favour of messages present in ftsResults, then newer CreatedAt, then the
// greater MessageID. At most limit results are returned (20 when limit is not
// positive). For a message found in both lists the SearchResult from
// ftsResults is kept. The returned slice is never nil.
func fuseSearchResults(ftsResults, semanticResults []SearchResult, limit int) []SearchResult {
	if limit <= 0 {
		limit = 20
	}

	capacity := len(ftsResults) + len(semanticResults)
	byID := make(map[string]*hybridSearchEntry, capacity)
	ranked := make([]*hybridSearchEntry, 0, capacity)

	accumulate := func(hits []SearchResult, weight float64, lexical bool) {
		for pos := range hits {
			id := hits[pos].MessageID
			slot, seen := byID[id]
			if !seen {
				slot = &hybridSearchEntry{result: hits[pos]}
				byID[id] = slot
				ranked = append(ranked, slot)
			}
			slot.hasFTS = slot.hasFTS || lexical
			slot.score += weight / (rrfK + float64(pos+1))
		}
	}
	// FTS goes first so that shared messages keep the FTS copy of the result.
	accumulate(ftsResults, ftsRRFWeight, true)
	accumulate(semanticResults, semanticRRFWeight, false)

	sort.SliceStable(ranked, func(i, j int) bool {
		a, b := ranked[i], ranked[j]
		switch {
		case a.score != b.score:
			return a.score > b.score
		case a.hasFTS != b.hasFTS:
			return a.hasFTS
		case !a.result.CreatedAt.Equal(b.result.CreatedAt):
			return a.result.CreatedAt.After(b.result.CreatedAt)
		default:
			return a.result.MessageID > b.result.MessageID
		}
	})

	if len(ranked) > limit {
		ranked = ranked[:limit]
	}
	fused := make([]SearchResult, len(ranked))
	for i, slot := range ranked {
		fused[i] = slot.result
	}
	return fused
}
func (s *Store) hasCompatibleMessageEmbeddings(ctx context.Context, opts SemanticSearchOptions) (bool, error) {
queryCtx, cancel := withQueryTimeout(ctx)
defer cancel()
@ -311,6 +388,28 @@ func (s *Store) hasCompatibleMessageEmbeddings(ctx context.Context, opts Semanti
return exists == 1, err
}
// HasMessageEmbeddings reports whether the message_embeddings table contains
// at least one row for the given provider, model, and input version.
//
// provider is trimmed and lower-cased, model and inputVersion are trimmed, and
// an empty inputVersion defaults to EmbeddingInputVersion. The query runs
// under the context returned by withQueryTimeout.
func (s *Store) HasMessageEmbeddings(ctx context.Context, provider, model, inputVersion string) (bool, error) {
	const existsQuery = `
select exists(
select 1
from message_embeddings
where provider = ?
and model = ?
and input_version = ?
)
`
	providerKey := strings.ToLower(strings.TrimSpace(provider))
	modelKey := strings.TrimSpace(model)
	versionKey := strings.TrimSpace(inputVersion)
	if versionKey == "" {
		versionKey = EmbeddingInputVersion
	}

	queryCtx, cancel := withQueryTimeout(ctx)
	defer cancel()

	var found int
	row := s.db.QueryRowContext(queryCtx, existsQuery, providerKey, modelKey, versionKey)
	err := row.Scan(&found)
	return found == 1, err
}
func (s *Store) CheckMessageFTS(ctx context.Context) error {
db, cleanup, err := s.openReadOnlyDB()
if err != nil {

View File

@ -446,6 +446,124 @@ func TestSearchMessagesSemanticErrors(t *testing.T) {
require.ErrorContains(t, err, "stored embedding vector is zero")
}
// TestSearchMessagesHybridFusesAndDeduplicates seeds three messages with
// stored vectors and checks that a hybrid search for "lock" with query vector
// [1, 0] returns each message exactly once, in the order m1, m2, m3.
func TestSearchMessagesHybridFusesAndDeduplicates(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	dbFile := filepath.Join(t.TempDir(), "discrawl.db")
	s, err := Open(ctx, dbFile)
	require.NoError(t, err)
	defer func() { _ = s.Close() }()

	require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
	require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))

	// Oldest message first; each one is a minute newer than the previous.
	base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
	seeds := []struct {
		id, authorID, content, rawJSON string
		createdAt                      time.Time
	}{
		{"m3", "u1", "panic stack trace", `{"author":{"username":"Alice"}}`, base},
		{"m2", "u2", "database worker stalled", `{"author":{"username":"Bob"}}`, base.Add(time.Minute)},
		{"m1", "u3", "panic database lock", `{"author":{"username":"Carol"}}`, base.Add(2 * time.Minute)},
	}
	for _, seed := range seeds {
		record := MessageRecord{
			ID:                seed.id,
			GuildID:           "g1",
			ChannelID:         "c1",
			ChannelName:       "general",
			AuthorID:          seed.authorID,
			MessageType:       0,
			CreatedAt:         seed.createdAt.Format(time.RFC3339Nano),
			Content:           seed.content,
			NormalizedContent: seed.content,
			RawJSON:           seed.rawJSON,
		}
		require.NoError(t, s.UpsertMessage(ctx, record))
	}

	vectors := []struct {
		messageID string
		values    []float32
	}{
		{"m1", []float32{0.9, 0.1}},
		{"m2", []float32{1, 0}},
		{"m3", []float32{0, 1}},
	}
	for _, vec := range vectors {
		require.NoError(t, insertTestEmbedding(ctx, s, vec.messageID, "ollama", "nomic-embed-text", vec.values))
	}

	lexical := SearchOptions{
		Query:    "lock",
		GuildIDs: []string{"g1"},
		Limit:    3,
	}
	semantic := SemanticSearchOptions{
		QueryVector:  []float32{1, 0},
		Provider:     "ollama",
		Model:        "nomic-embed-text",
		InputVersion: EmbeddingInputVersion,
		Dimensions:   2,
		GuildIDs:     []string{"g1"},
	}
	fused, err := s.SearchMessagesHybrid(ctx, lexical, semantic)
	require.NoError(t, err)
	require.Equal(t, []string{"m1", "m2", "m3"}, searchResultIDs(fused))
}
func TestSearchMessagesHybridTieBreaksTowardFTS(t *testing.T) {
t.Parallel()
created := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
results := fuseSearchResults(
[]SearchResult{{MessageID: "fts", CreatedAt: created}},
[]SearchResult{{MessageID: "semantic", CreatedAt: created.Add(time.Hour)}},
2,
)
require.Equal(t, []string{"fts", "semantic"}, searchResultIDs(results))
}
// TestSearchMessagesHybridPropagatesCorruptEmbeddings stores a 4-byte blob
// declared as a 2-dimensional vector and expects SearchMessagesHybrid to
// surface the resulting "vector length mismatch" error instead of hiding it.
func TestSearchMessagesHybridPropagatesCorruptEmbeddings(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	dbFile := filepath.Join(t.TempDir(), "discrawl.db")
	s, err := Open(ctx, dbFile)
	require.NoError(t, err)
	defer func() { _ = s.Close() }()

	record := MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
		Content:           "panic database lock",
		NormalizedContent: "panic database lock",
		RawJSON:           `{}`,
	}
	require.NoError(t, s.UpsertMessage(ctx, record))

	// Blob size does not match the declared dimension count of 2.
	corrupt := []byte{0, 0, 0, 0}
	require.NoError(t, insertTestEmbeddingBlob(ctx, s, "m1", "ollama", "nomic-embed-text", 2, corrupt))

	lexical := SearchOptions{
		Query: "panic",
		Limit: 10,
	}
	semantic := SemanticSearchOptions{
		QueryVector:  []float32{1, 0},
		Provider:     "ollama",
		Model:        "nomic-embed-text",
		InputVersion: EmbeddingInputVersion,
		Dimensions:   2,
	}
	_, err = s.SearchMessagesHybrid(ctx, lexical, semantic)
	require.ErrorContains(t, err, "vector length mismatch")
}
func insertTestEmbedding(ctx context.Context, s *Store, messageID, provider, model string, vector []float32) error {
blob, err := EncodeEmbeddingVector(vector)
if err != nil {