feat(search): add hybrid message search
This commit is contained in:
parent
ffc622cd5e
commit
3ea1d4aa7f
@ -11,6 +11,7 @@ All notable changes to `discrawl` will be documented in this file.
|
||||
- `embed` now drains the queued embedding backlog in bounded batches, requeues safely on provider throttling, and drops stale stored vectors when messages no longer have embeddable content
|
||||
- Git-backed snapshots now keep embedding queue state and generated vectors local to each archive, so subscribers no longer inherit misleading embedding backlog metadata. (#38) Thanks @GaosCode.
|
||||
- semantic message search now ranks across the full compatible local vector set instead of only the newest candidate window. (#36) Thanks @GaosCode.
|
||||
- hybrid message search now fuses FTS with local semantic vectors while avoiding embedding-provider calls when no local vectors exist. (#37) Thanks @GaosCode.
|
||||
|
||||
## 0.3.0 - 2026-04-21
|
||||
|
||||
|
||||
@ -634,6 +634,91 @@ func TestSearchSemanticCommandUsesStoredEmbeddings(t *testing.T) {
|
||||
require.Equal(t, 2, requests)
|
||||
}
|
||||
|
||||
func TestSearchHybridCommandFusesResults(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
dbPath := filepath.Join(dir, "discrawl.db")
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/embeddings", r.URL.Path)
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, "local-model", req.Model)
|
||||
assert.Equal(t, []string{"panic"}, req.Input)
|
||||
_, _ = w.Write([]byte(`{"model":"local-model","data":[{"index":0,"embedding":[1,0]}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = dbPath
|
||||
cfg.Search.DefaultMode = "hybrid"
|
||||
cfg.Search.Embeddings.Enabled = true
|
||||
cfg.Search.Embeddings.Provider = "openai_compatible"
|
||||
cfg.Search.Embeddings.Model = "local-model"
|
||||
cfg.Search.Embeddings.BaseURL = server.URL
|
||||
cfg.Search.Embeddings.APIKeyEnv = ""
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
s, err := store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
|
||||
ID: "m3",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "Alice",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Format(time.RFC3339Nano),
|
||||
Content: "panic stack trace",
|
||||
NormalizedContent: "panic stack trace",
|
||||
RawJSON: `{"author":{"username":"Alice"}}`,
|
||||
}))
|
||||
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
|
||||
ID: "m2",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u2",
|
||||
AuthorName: "Bob",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "database worker stalled",
|
||||
NormalizedContent: "database worker stalled",
|
||||
RawJSON: `{"author":{"username":"Bob"}}`,
|
||||
}))
|
||||
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u3",
|
||||
AuthorName: "Carol",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(2 * time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "panic database lock",
|
||||
NormalizedContent: "panic database lock",
|
||||
RawJSON: `{"author":{"username":"Carol"}}`,
|
||||
}))
|
||||
require.NoError(t, insertCLIEmbedding(ctx, s, "m1", "openai_compatible", "local-model", []float32{0.9, 0.1}))
|
||||
require.NoError(t, insertCLIEmbedding(ctx, s, "m2", "openai_compatible", "local-model", []float32{1, 0}))
|
||||
require.NoError(t, insertCLIEmbedding(ctx, s, "m3", "openai_compatible", "local-model", []float32{0, 1}))
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--limit", "3", "panic"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "panic database lock")
|
||||
require.Contains(t, out.String(), "database worker stalled")
|
||||
require.Contains(t, out.String(), "panic stack trace")
|
||||
}
|
||||
|
||||
func TestSearchSemanticCommandErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
@ -651,10 +736,6 @@ func TestSearchSemanticCommandErrors(t *testing.T) {
|
||||
require.Equal(t, 2, ExitCode(err))
|
||||
require.ErrorContains(t, err, `unsupported search mode "bogus"`)
|
||||
|
||||
err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
|
||||
require.Equal(t, 1, ExitCode(err))
|
||||
require.ErrorContains(t, err, "hybrid search is not implemented yet")
|
||||
|
||||
err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
|
||||
require.Equal(t, 1, ExitCode(err))
|
||||
require.ErrorContains(t, err, "embeddings are disabled")
|
||||
@ -675,6 +756,81 @@ func TestSearchSemanticCommandErrors(t *testing.T) {
|
||||
require.ErrorContains(t, err, "embedding query failed")
|
||||
}
|
||||
|
||||
func TestSearchHybridCommandFallsBackToFTS(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
dbPath := filepath.Join(dir, "discrawl.db")
|
||||
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = dbPath
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
s, err := store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "Alice",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "panic exact match",
|
||||
NormalizedContent: "panic exact match",
|
||||
RawJSON: `{"author":{"username":"Alice"}}`,
|
||||
}))
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "panic"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "panic exact match")
|
||||
|
||||
okRequests := 0
|
||||
okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
okRequests++
|
||||
_, _ = w.Write([]byte(`{"model":"local-model","data":[{"index":0,"embedding":[1,0]}]}`))
|
||||
}))
|
||||
defer okServer.Close()
|
||||
cfg.Search.Embeddings.Enabled = true
|
||||
cfg.Search.Embeddings.Provider = "openai_compatible"
|
||||
cfg.Search.Embeddings.Model = "local-model"
|
||||
cfg.Search.Embeddings.BaseURL = okServer.URL
|
||||
cfg.Search.Embeddings.APIKeyEnv = ""
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
out.Reset()
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "panic"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "panic exact match")
|
||||
require.Equal(t, 0, okRequests)
|
||||
|
||||
s, err = store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, insertCLIEmbedding(ctx, s, "m1", "openai_compatible", "local-model", []float32{1, 0}))
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
failedRequests := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
failedRequests++
|
||||
http.Error(w, "nope", http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
cfg.Search.Embeddings.Enabled = true
|
||||
cfg.Search.Embeddings.Provider = "openai_compatible"
|
||||
cfg.Search.Embeddings.Model = "local-model"
|
||||
cfg.Search.Embeddings.BaseURL = server.URL
|
||||
cfg.Search.Embeddings.APIKeyEnv = ""
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
out.Reset()
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "panic"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "panic exact match")
|
||||
require.Equal(t, 1, failedRequests)
|
||||
}
|
||||
|
||||
func insertCLIEmbedding(ctx context.Context, s *store.Store, messageID, provider, model string, vector []float32) error {
|
||||
blob, err := store.EncodeEmbeddingVector(vector)
|
||||
if err != nil {
|
||||
|
||||
@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -51,15 +52,57 @@ func (r *runtime) runSearch(args []string) error {
|
||||
}
|
||||
return r.print(results)
|
||||
case "hybrid":
|
||||
return fmt.Errorf("hybrid search is not implemented yet")
|
||||
results, err := r.searchMessagesHybrid(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.print(results)
|
||||
default:
|
||||
return usageErr(fmt.Errorf("unsupported search mode %q", *mode))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.SearchResult, error) {
|
||||
semanticOpts, err := r.semanticSearchOptions(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.store.SearchMessagesSemantic(r.ctx, semanticOpts)
|
||||
}
|
||||
|
||||
func (r *runtime) searchMessagesHybrid(opts store.SearchOptions) ([]store.SearchResult, error) {
|
||||
if !r.cfg.Search.Embeddings.Enabled {
|
||||
return nil, fmt.Errorf("embeddings are disabled; enable [search.embeddings] first")
|
||||
return r.store.SearchMessages(r.ctx, opts)
|
||||
}
|
||||
hasEmbeddings, err := r.store.HasMessageEmbeddings(
|
||||
r.ctx,
|
||||
r.cfg.Search.Embeddings.Provider,
|
||||
r.cfg.Search.Embeddings.Model,
|
||||
store.EmbeddingInputVersion,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !hasEmbeddings {
|
||||
return r.store.SearchMessages(r.ctx, opts)
|
||||
}
|
||||
semanticOpts, err := r.semanticSearchOptions(opts)
|
||||
if err != nil {
|
||||
return r.store.SearchMessages(r.ctx, opts)
|
||||
}
|
||||
results, err := r.store.SearchMessagesHybrid(r.ctx, opts, semanticOpts)
|
||||
if err != nil {
|
||||
if hybridSemanticUnavailable(err) {
|
||||
return r.store.SearchMessages(r.ctx, opts)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *runtime) semanticSearchOptions(opts store.SearchOptions) (store.SemanticSearchOptions, error) {
|
||||
if !r.cfg.Search.Embeddings.Enabled {
|
||||
return store.SemanticSearchOptions{}, fmt.Errorf("embeddings are disabled; enable [search.embeddings] first")
|
||||
}
|
||||
providerFactory := r.newEmbed
|
||||
if providerFactory == nil {
|
||||
@ -69,21 +112,21 @@ func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.Sear
|
||||
}
|
||||
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create embedding provider: %w", err)
|
||||
return store.SemanticSearchOptions{}, fmt.Errorf("create embedding provider: %w", err)
|
||||
}
|
||||
batch, err := provider.Embed(r.ctx, []string{opts.Query})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedding query failed: %w", err)
|
||||
return store.SemanticSearchOptions{}, fmt.Errorf("embedding query failed: %w", err)
|
||||
}
|
||||
if len(batch.Vectors) != 1 {
|
||||
return nil, fmt.Errorf("embedding query returned %d vectors for 1 input", len(batch.Vectors))
|
||||
return store.SemanticSearchOptions{}, fmt.Errorf("embedding query returned %d vectors for 1 input", len(batch.Vectors))
|
||||
}
|
||||
queryVector := batch.Vectors[0]
|
||||
dimensions := batch.Dimensions
|
||||
if dimensions == 0 {
|
||||
dimensions = len(queryVector)
|
||||
}
|
||||
return r.store.SearchMessagesSemantic(r.ctx, store.SemanticSearchOptions{
|
||||
return store.SemanticSearchOptions{
|
||||
QueryVector: queryVector,
|
||||
Provider: r.cfg.Search.Embeddings.Provider,
|
||||
Model: r.cfg.Search.Embeddings.Model,
|
||||
@ -94,7 +137,11 @@ func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.Sear
|
||||
Author: opts.Author,
|
||||
Limit: opts.Limit,
|
||||
IncludeEmpty: opts.IncludeEmpty,
|
||||
})
|
||||
}, nil
|
||||
}
|
||||
|
||||
func hybridSemanticUnavailable(err error) bool {
|
||||
return errors.Is(err, store.ErrNoCompatibleEmbeddings) || strings.HasPrefix(err.Error(), "semantic query embedding ")
|
||||
}
|
||||
|
||||
func (r *runtime) runSQL(args []string) error {
|
||||
|
||||
@ -20,6 +20,9 @@ const (
|
||||
searchCandidateCap = 5000
|
||||
searchCandidateMultiple = 20
|
||||
messageFTSHealthProbe = "__discrawl_probe__"
|
||||
rrfK = 60.0
|
||||
ftsRRFWeight = 1.0
|
||||
semanticRRFWeight = 1.0
|
||||
)
|
||||
|
||||
var ErrNoCompatibleEmbeddings = errors.New("no compatible message embeddings for provider/model/input version; run discrawl embed --rebuild")
|
||||
@ -294,6 +297,80 @@ func semanticScoreLess(left, right semanticScoredResult) bool {
|
||||
return left.result.MessageID > right.result.MessageID
|
||||
}
|
||||
|
||||
func (s *Store) SearchMessagesHybrid(ctx context.Context, opts SearchOptions, semanticOpts SemanticSearchOptions) ([]SearchResult, error) {
|
||||
limit := opts.Limit
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
candidateLimit := searchCandidateLimit(limit)
|
||||
ftsOpts := opts
|
||||
ftsOpts.Limit = candidateLimit
|
||||
semanticOpts.Limit = candidateLimit
|
||||
|
||||
ftsResults, err := s.SearchMessages(ctx, ftsOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
semanticResults, err := s.SearchMessagesSemantic(ctx, semanticOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fuseSearchResults(ftsResults, semanticResults, limit), nil
|
||||
}
|
||||
|
||||
type hybridSearchEntry struct {
|
||||
result SearchResult
|
||||
score float64
|
||||
hasFTS bool
|
||||
}
|
||||
|
||||
func fuseSearchResults(ftsResults, semanticResults []SearchResult, limit int) []SearchResult {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
entries := make(map[string]*hybridSearchEntry, len(ftsResults)+len(semanticResults))
|
||||
addResults := func(results []SearchResult, weight float64, fts bool) {
|
||||
for index, result := range results {
|
||||
entry := entries[result.MessageID]
|
||||
if entry == nil {
|
||||
entry = &hybridSearchEntry{result: result}
|
||||
entries[result.MessageID] = entry
|
||||
}
|
||||
if fts {
|
||||
entry.hasFTS = true
|
||||
}
|
||||
entry.score += weight / (rrfK + float64(index+1))
|
||||
}
|
||||
}
|
||||
addResults(ftsResults, ftsRRFWeight, true)
|
||||
addResults(semanticResults, semanticRRFWeight, false)
|
||||
|
||||
merged := make([]hybridSearchEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
merged = append(merged, *entry)
|
||||
}
|
||||
sort.SliceStable(merged, func(i, j int) bool {
|
||||
if merged[i].score != merged[j].score {
|
||||
return merged[i].score > merged[j].score
|
||||
}
|
||||
if merged[i].hasFTS != merged[j].hasFTS {
|
||||
return merged[i].hasFTS
|
||||
}
|
||||
if !merged[i].result.CreatedAt.Equal(merged[j].result.CreatedAt) {
|
||||
return merged[i].result.CreatedAt.After(merged[j].result.CreatedAt)
|
||||
}
|
||||
return merged[i].result.MessageID > merged[j].result.MessageID
|
||||
})
|
||||
if len(merged) > limit {
|
||||
merged = merged[:limit]
|
||||
}
|
||||
out := make([]SearchResult, 0, len(merged))
|
||||
for _, entry := range merged {
|
||||
out = append(out, entry.result)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Store) hasCompatibleMessageEmbeddings(ctx context.Context, opts SemanticSearchOptions) (bool, error) {
|
||||
queryCtx, cancel := withQueryTimeout(ctx)
|
||||
defer cancel()
|
||||
@ -311,6 +388,28 @@ func (s *Store) hasCompatibleMessageEmbeddings(ctx context.Context, opts Semanti
|
||||
return exists == 1, err
|
||||
}
|
||||
|
||||
func (s *Store) HasMessageEmbeddings(ctx context.Context, provider, model, inputVersion string) (bool, error) {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
model = strings.TrimSpace(model)
|
||||
inputVersion = strings.TrimSpace(inputVersion)
|
||||
if inputVersion == "" {
|
||||
inputVersion = EmbeddingInputVersion
|
||||
}
|
||||
queryCtx, cancel := withQueryTimeout(ctx)
|
||||
defer cancel()
|
||||
var exists int
|
||||
err := s.db.QueryRowContext(queryCtx, `
|
||||
select exists(
|
||||
select 1
|
||||
from message_embeddings
|
||||
where provider = ?
|
||||
and model = ?
|
||||
and input_version = ?
|
||||
)
|
||||
`, provider, model, inputVersion).Scan(&exists)
|
||||
return exists == 1, err
|
||||
}
|
||||
|
||||
func (s *Store) CheckMessageFTS(ctx context.Context) error {
|
||||
db, cleanup, err := s.openReadOnlyDB()
|
||||
if err != nil {
|
||||
|
||||
@ -446,6 +446,124 @@ func TestSearchMessagesSemanticErrors(t *testing.T) {
|
||||
require.ErrorContains(t, err, "stored embedding vector is zero")
|
||||
}
|
||||
|
||||
func TestSearchMessagesHybridFusesAndDeduplicates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
|
||||
messages := []MessageRecord{
|
||||
{
|
||||
ID: "m3",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Format(time.RFC3339Nano),
|
||||
Content: "panic stack trace",
|
||||
NormalizedContent: "panic stack trace",
|
||||
RawJSON: `{"author":{"username":"Alice"}}`,
|
||||
},
|
||||
{
|
||||
ID: "m2",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u2",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "database worker stalled",
|
||||
NormalizedContent: "database worker stalled",
|
||||
RawJSON: `{"author":{"username":"Bob"}}`,
|
||||
},
|
||||
{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u3",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(2 * time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "panic database lock",
|
||||
NormalizedContent: "panic database lock",
|
||||
RawJSON: `{"author":{"username":"Carol"}}`,
|
||||
},
|
||||
}
|
||||
for _, message := range messages {
|
||||
require.NoError(t, s.UpsertMessage(ctx, message))
|
||||
}
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m1", "ollama", "nomic-embed-text", []float32{0.9, 0.1}))
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m2", "ollama", "nomic-embed-text", []float32{1, 0}))
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m3", "ollama", "nomic-embed-text", []float32{0, 1}))
|
||||
|
||||
results, err := s.SearchMessagesHybrid(ctx, SearchOptions{
|
||||
Query: "lock",
|
||||
GuildIDs: []string{"g1"},
|
||||
Limit: 3,
|
||||
}, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
GuildIDs: []string{"g1"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"m1", "m2", "m3"}, searchResultIDs(results))
|
||||
}
|
||||
|
||||
func TestSearchMessagesHybridTieBreaksTowardFTS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
created := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
|
||||
results := fuseSearchResults(
|
||||
[]SearchResult{{MessageID: "fts", CreatedAt: created}},
|
||||
[]SearchResult{{MessageID: "semantic", CreatedAt: created.Add(time.Hour)}},
|
||||
2,
|
||||
)
|
||||
require.Equal(t, []string{"fts", "semantic"}, searchResultIDs(results))
|
||||
}
|
||||
|
||||
func TestSearchMessagesHybridPropagatesCorruptEmbeddings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "panic database lock",
|
||||
NormalizedContent: "panic database lock",
|
||||
RawJSON: `{}`,
|
||||
}))
|
||||
require.NoError(t, insertTestEmbeddingBlob(ctx, s, "m1", "ollama", "nomic-embed-text", 2, []byte{0, 0, 0, 0}))
|
||||
|
||||
_, err = s.SearchMessagesHybrid(ctx, SearchOptions{
|
||||
Query: "panic",
|
||||
Limit: 10,
|
||||
}, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
})
|
||||
require.ErrorContains(t, err, "vector length mismatch")
|
||||
}
|
||||
|
||||
func insertTestEmbedding(ctx context.Context, s *Store, messageID, provider, model string, vector []float32) error {
|
||||
blob, err := EncodeEmbeddingVector(vector)
|
||||
if err != nil {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user