discrawl/internal/store/query.go
2026-04-24 20:36:03 +01:00

916 lines
25 KiB
Go

package store
import (
"context"
"database/sql"
"errors"
"fmt"
"math"
"os"
"sort"
"strings"
"time"
)
// Tunables shared by the query helpers in this file.
const (
// queryTimeout is the deadline withQueryTimeout applies when the caller's
// context does not already carry one.
queryTimeout = 15 * time.Second
// semanticQueryTimeout bounds the embedding scan in SearchMessagesSemantic.
semanticQueryTimeout = 2 * time.Minute
// queryRowLimit is the most rows queryRows collects before it fails.
queryRowLimit = 50000
// searchCandidateFloor, searchCandidateCap and searchCandidateMultiple
// drive searchCandidateLimit: the pool is limit*multiple, never below the
// floor and never above the cap.
searchCandidateFloor = 200
searchCandidateCap = 5000
searchCandidateMultiple = 20
// messageFTSHealthProbe is the match term CheckMessageFTS sends to the
// message_fts table.
messageFTSHealthProbe = "__discrawl_probe__"
// rrfK is the constant added to the 1-based rank in fuseSearchResults;
// ftsRRFWeight and semanticRRFWeight are the per-list numerators.
rrfK = 60.0
ftsRRFWeight = 1.0
semanticRRFWeight = 1.0
)
// ErrNoCompatibleEmbeddings is returned by SearchMessagesSemantic when the
// search scored nothing and message_embeddings holds no row for the requested
// provider, model, input version and dimensions.
var ErrNoCompatibleEmbeddings = errors.New("no compatible message embeddings for provider/model/input version; run discrawl embed --rebuild")
// SemanticSearchOptions configures SearchMessagesSemantic.
type SemanticSearchOptions struct {
// QueryVector is the embedded query; it must be non-empty and non-zero.
QueryVector []float32
// Provider is trimmed and lower-cased before matching stored embeddings.
Provider string
// Model is trimmed before matching stored embeddings.
Model string
// InputVersion defaults to EmbeddingInputVersion when blank.
InputVersion string
// Dimensions defaults to len(QueryVector) when <= 0 and must equal it.
Dimensions int
// GuildIDs, when non-empty, restricts results to these guilds.
GuildIDs []string
// Channel matches a channel ID exactly or a channel name by substring.
Channel string
// Author matches an author ID exactly or a nick/global name/username
// (read from the message's raw JSON) by substring.
Author string
// Limit caps the number of results; values <= 0 become 20.
Limit int
// IncludeEmpty keeps messages whose normalized_content is blank.
IncludeEmpty bool
}
// GetSyncState returns the cursor stored in sync_state for scope.
//
// A missing row and a NULL cursor both yield an empty string with a nil error.
// Any other query or scan failure is returned unchanged.
func (s *Store) GetSyncState(ctx context.Context, scope string) (string, error) {
	var cursor sql.NullString
	err := s.db.QueryRowContext(ctx, `select cursor from sync_state where scope = ?`, scope).Scan(&cursor)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// errors.Is (rather than ==) also recognises a wrapped ErrNoRows and
		// matches how CheckMessageFTS tests for the same sentinel.
		return "", nil
	case err != nil:
		return "", err
	}
	return cursor.String, nil
}
// ChannelMessageBounds returns min(id) and max(id) over the messages stored
// for channelID, in that order. Both are empty strings when the aggregates are
// NULL (for example when the channel has no stored messages).
func (s *Store) ChannelMessageBounds(ctx context.Context, channelID string) (string, string, error) {
	const boundsSQL = `
select min(id), max(id)
from messages
where channel_id = ?
`
	var lowest, highest sql.NullString
	row := s.db.QueryRowContext(ctx, boundsSQL, channelID)
	if err := row.Scan(&lowest, &highest); err != nil {
		return "", "", err
	}
	return lowest.String, highest.String, nil
}
// SearchMessages runs a full-text search over message_fts and returns the
// most recent matches first (by FTS rowid).
//
// A blank query returns (nil, nil). opts.Limit <= 0 becomes 20. The inner
// query collects up to searchCandidateLimit(opts.Limit) FTS hits, which are
// then joined to messages/channels, filtered for empty content unless
// opts.IncludeEmpty is set, and cut to opts.Limit. If the FTS statement itself
// fails, the search is retried through searchFallback with a fresh timeout.
func (s *Store) SearchMessages(ctx context.Context, opts SearchOptions) ([]SearchResult, error) {
	if strings.TrimSpace(opts.Query) == "" {
		return nil, nil
	}
	if opts.Limit <= 0 {
		opts.Limit = 20
	}

	filters := []string{"message_fts match ?"}
	params := []any{normalizeFTSQuery(opts.Query)}
	if count := len(opts.GuildIDs); count > 0 {
		filters = append(filters, "message_fts.guild_id in ("+placeholders(count)+")")
		for i := range opts.GuildIDs {
			params = append(params, opts.GuildIDs[i])
		}
	}
	if strings.TrimSpace(opts.Channel) != "" {
		filters = append(filters, "(message_fts.channel_id = ? or message_fts.channel_name like ?)")
		params = append(params, opts.Channel, "%"+opts.Channel+"%")
	}
	if strings.TrimSpace(opts.Author) != "" {
		filters = append(filters, "(message_fts.author_id = ? or message_fts.author_name like ?)")
		params = append(params, opts.Author, "%"+opts.Author+"%")
	}
	// Trailing placeholders, in statement order: candidate cap of the CTE,
	// the include-empty switch, and the final row limit.
	params = append(params, searchCandidateLimit(opts.Limit), opts.IncludeEmpty, opts.Limit)

	statement := `
with recent_matches as (
select
rowid,
message_id,
guild_id,
channel_id,
author_id,
coalesce(author_name, '') as author_name,
coalesce(channel_name, '') as channel_name
from message_fts
where ` + strings.Join(filters, " and ") + `
order by rowid desc
limit ?
)
select
m.id, m.guild_id, m.channel_id, coalesce(c.name, recent_matches.channel_name),
coalesce(m.author_id, ''), recent_matches.author_name,
case
when trim(coalesce(m.content, '')) <> '' then m.content
else m.normalized_content
end,
m.created_at
from recent_matches
join messages m on m.id = recent_matches.message_id
left join channels c on c.id = m.channel_id
where (? or trim(coalesce(m.normalized_content, '')) <> '')
order by recent_matches.rowid desc
limit ?
`

	queryCtx, cancel := withQueryTimeout(ctx)
	defer cancel()
	rows, err := s.db.QueryContext(queryCtx, statement, params...)
	if err != nil {
		// The FTS statement failed; retry with the LIKE-based fallback,
		// derived from the caller's context rather than the spent one.
		fallbackCtx, fallbackCancel := withQueryTimeout(ctx)
		defer fallbackCancel()
		return s.searchFallback(fallbackCtx, opts)
	}
	defer func() { _ = rows.Close() }()

	var results []SearchResult
	for rows.Next() {
		var (
			hit       SearchResult
			createdAt string
		)
		scanErr := rows.Scan(
			&hit.MessageID, &hit.GuildID, &hit.ChannelID, &hit.ChannelName,
			&hit.AuthorID, &hit.AuthorName, &hit.Content, &createdAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		hit.CreatedAt = parseTime(createdAt)
		results = append(results, hit)
	}
	return results, rows.Err()
}
// SearchMessagesSemantic ranks stored message embeddings by cosine similarity
// to opts.QueryVector and returns the best opts.Limit messages.
//
// Options are normalised first (provider lower-cased, strings trimmed,
// InputVersion/Limit/Dimensions defaulted). An empty or zero query vector, or
// a vector whose length differs from opts.Dimensions, is rejected. Every
// embedding matching provider/model/input version/dimensions and the optional
// guild/channel/author filters is scanned under semanticQueryTimeout; any
// stored vector that cannot be decoded, has the wrong length, or cannot be
// scored aborts the search with an error.
//
// When nothing is scored, the result is ErrNoCompatibleEmbeddings if no
// compatible embedding exists at all, otherwise an empty non-nil slice.
func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchOptions) ([]SearchResult, error) {
	// SQL expression yielding the display name recorded in the raw payload.
	const authorExpr = `coalesce(
json_extract(m.raw_json, '$.member.nick'),
json_extract(m.raw_json, '$.author.global_name'),
json_extract(m.raw_json, '$.author.username'),
''
)`

	opts.Provider = strings.ToLower(strings.TrimSpace(opts.Provider))
	opts.Model = strings.TrimSpace(opts.Model)
	if opts.InputVersion = strings.TrimSpace(opts.InputVersion); opts.InputVersion == "" {
		opts.InputVersion = EmbeddingInputVersion
	}
	if opts.Limit <= 0 {
		opts.Limit = 20
	}
	if len(opts.QueryVector) == 0 {
		return nil, errors.New("semantic query embedding returned an empty vector")
	}
	if opts.Dimensions <= 0 {
		opts.Dimensions = len(opts.QueryVector)
	}
	if got := len(opts.QueryVector); got != opts.Dimensions {
		return nil, fmt.Errorf("semantic query embedding dimensions mismatch: got %d want %d", got, opts.Dimensions)
	}
	queryNorm := vectorNorm(opts.QueryVector)
	if queryNorm == 0 {
		return nil, errors.New("semantic query embedding returned a zero vector")
	}

	filters := []string{
		"e.provider = ?",
		"e.model = ?",
		"e.input_version = ?",
		"e.dimensions = ?",
	}
	params := []any{opts.Provider, opts.Model, opts.InputVersion, opts.Dimensions}
	if count := len(opts.GuildIDs); count > 0 {
		filters = append(filters, "m.guild_id in ("+placeholders(count)+")")
		for i := range opts.GuildIDs {
			params = append(params, opts.GuildIDs[i])
		}
	}
	if strings.TrimSpace(opts.Channel) != "" {
		filters = append(filters, "(m.channel_id = ? or c.name like ?)")
		params = append(params, opts.Channel, "%"+opts.Channel+"%")
	}
	if strings.TrimSpace(opts.Author) != "" {
		filters = append(filters, "(m.author_id = ? or "+authorExpr+" like ?)")
		params = append(params, opts.Author, "%"+opts.Author+"%")
	}
	if !opts.IncludeEmpty {
		filters = append(filters, "trim(coalesce(m.normalized_content, '')) <> ''")
	}

	statement := `
select
m.id,
m.guild_id,
m.channel_id,
coalesce(c.name, ''),
coalesce(m.author_id, ''),
` + authorExpr + `,
case
when trim(coalesce(m.content, '')) <> '' then m.content
else m.normalized_content
end,
m.created_at,
e.dimensions,
e.embedding_blob
from message_embeddings e
join messages m on m.id = e.message_id
left join channels c on c.id = m.channel_id
where ` + strings.Join(filters, " and ") + `
`

	queryCtx, cancel := context.WithTimeout(ctx, semanticQueryTimeout)
	defer cancel()
	rows, err := s.db.QueryContext(queryCtx, statement, params...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	// best stays sorted by semanticScoreLess and never grows past opts.Limit.
	best := make([]semanticScoredResult, 0, opts.Limit)
	offer := func(candidate semanticScoredResult) {
		pos := sort.Search(len(best), func(i int) bool {
			return semanticScoreLess(candidate, best[i])
		})
		if pos >= opts.Limit {
			return // ranks below everything we are allowed to keep
		}
		best = append(best, semanticScoredResult{})
		copy(best[pos+1:], best[pos:])
		best[pos] = candidate
		if len(best) > opts.Limit {
			best = best[:opts.Limit]
		}
	}

	for rows.Next() {
		var (
			hit        SearchResult
			createdAt  string
			storedDims int
			encoded    []byte
		)
		scanErr := rows.Scan(
			&hit.MessageID, &hit.GuildID, &hit.ChannelID, &hit.ChannelName,
			&hit.AuthorID, &hit.AuthorName, &hit.Content, &createdAt,
			&storedDims, &encoded,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		if storedDims != opts.Dimensions {
			return nil, fmt.Errorf("stored embedding dimensions mismatch for message %s: got %d want %d", hit.MessageID, storedDims, opts.Dimensions)
		}
		stored, err := DecodeEmbeddingVector(encoded)
		if err != nil {
			return nil, fmt.Errorf("decode embedding for message %s: %w", hit.MessageID, err)
		}
		if len(stored) != storedDims {
			return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", hit.MessageID, len(stored), storedDims)
		}
		similarity, err := cosineSimilarity(opts.QueryVector, queryNorm, stored)
		if err != nil {
			return nil, fmt.Errorf("score embedding for message %s: %w", hit.MessageID, err)
		}
		hit.CreatedAt = parseTime(createdAt)
		offer(semanticScoredResult{result: hit, score: similarity})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	if len(best) == 0 {
		// Tell "filters matched nothing" apart from "nothing was ever
		// embedded for this provider/model/version".
		compatible, err := s.hasCompatibleMessageEmbeddings(ctx, opts)
		switch {
		case err != nil:
			return nil, err
		case !compatible:
			return nil, ErrNoCompatibleEmbeddings
		}
		return []SearchResult{}, nil
	}

	results := make([]SearchResult, 0, len(best))
	for i := range best {
		results = append(results, best[i].result)
	}
	return results, nil
}
// semanticScoredResult pairs a search hit with the cosine similarity computed
// for it in SearchMessagesSemantic.
type semanticScoredResult struct {
// result is the message row as returned to callers.
result SearchResult
// score is the value returned by cosineSimilarity; higher ranks first.
score float64
}
// semanticScoreLess reports whether left should be ranked ahead of right:
// higher score first, then the more recent CreatedAt, then the greater
// MessageID as the final tie-breaker.
func semanticScoreLess(left, right semanticScoredResult) bool {
	switch {
	case left.score != right.score:
		return left.score > right.score
	case !left.result.CreatedAt.Equal(right.result.CreatedAt):
		return left.result.CreatedAt.After(right.result.CreatedAt)
	default:
		return left.result.MessageID > right.result.MessageID
	}
}
// SearchMessagesHybrid runs the full-text search and the semantic search with
// an enlarged candidate limit and merges both rankings via fuseSearchResults.
//
// opts.Limit <= 0 is treated as 20; both underlying searches are asked for
// searchCandidateLimit(limit) rows. An error from either search aborts the
// call.
func (s *Store) SearchMessagesHybrid(ctx context.Context, opts SearchOptions, semanticOpts SemanticSearchOptions) ([]SearchResult, error) {
	want := opts.Limit
	if want <= 0 {
		want = 20
	}

	// Both option structs are local copies, so widening their limits here
	// does not leak back to the caller.
	pool := searchCandidateLimit(want)
	opts.Limit = pool
	semanticOpts.Limit = pool

	lexical, err := s.SearchMessages(ctx, opts)
	if err != nil {
		return nil, err
	}
	semantic, err := s.SearchMessagesSemantic(ctx, semanticOpts)
	if err != nil {
		return nil, err
	}
	return fuseSearchResults(lexical, semantic, want), nil
}
// hybridSearchEntry accumulates the fused rank of one message inside
// fuseSearchResults.
type hybridSearchEntry struct {
// result is the row from the first list that contained the message.
result SearchResult
// score is the sum of weight/(rrfK+rank) over the lists containing it.
score float64
// hasFTS records that the full-text list contained the message; it breaks
// score ties in favour of full-text hits.
hasFTS bool
}
// fuseSearchResults merges a full-text ranking and a semantic ranking into a
// single list of at most limit results (limit <= 0 means 20).
//
// Each appearance of a message at 1-based rank r in a list adds
// weight/(rrfK+r) to its score. Messages are keyed by MessageID and keep the
// row from the first list they were seen in, full-text first. Ordering is by
// score descending, then full-text hits first, then newer CreatedAt, then
// greater MessageID.
func fuseSearchResults(ftsResults, semanticResults []SearchResult, limit int) []SearchResult {
	if limit <= 0 {
		limit = 20
	}

	type rankedList struct {
		results []SearchResult
		weight  float64
		lexical bool
	}
	lists := [...]rankedList{
		{results: ftsResults, weight: ftsRRFWeight, lexical: true},
		{results: semanticResults, weight: semanticRRFWeight, lexical: false},
	}

	byID := make(map[string]*hybridSearchEntry, len(ftsResults)+len(semanticResults))
	for _, list := range lists {
		for rank, result := range list.results {
			entry, seen := byID[result.MessageID]
			if !seen {
				entry = &hybridSearchEntry{result: result}
				byID[result.MessageID] = entry
			}
			if list.lexical {
				entry.hasFTS = true
			}
			entry.score += list.weight / (rrfK + float64(rank+1))
		}
	}

	fused := make([]hybridSearchEntry, 0, len(byID))
	for _, entry := range byID {
		fused = append(fused, *entry)
	}
	sort.SliceStable(fused, func(i, j int) bool {
		a, b := &fused[i], &fused[j]
		switch {
		case a.score != b.score:
			return a.score > b.score
		case a.hasFTS != b.hasFTS:
			return a.hasFTS
		case !a.result.CreatedAt.Equal(b.result.CreatedAt):
			return a.result.CreatedAt.After(b.result.CreatedAt)
		default:
			return a.result.MessageID > b.result.MessageID
		}
	})
	if len(fused) > limit {
		fused = fused[:limit]
	}

	results := make([]SearchResult, 0, len(fused))
	for i := range fused {
		results = append(results, fused[i].result)
	}
	return results
}
// hasCompatibleMessageEmbeddings reports whether message_embeddings contains
// at least one row matching opts' provider, model, input version and
// dimensions. opts is used as given; no normalisation happens here.
func (s *Store) hasCompatibleMessageEmbeddings(ctx context.Context, opts SemanticSearchOptions) (bool, error) {
	const existsSQL = `
select exists(
select 1
from message_embeddings
where provider = ?
and model = ?
and input_version = ?
and dimensions = ?
)
`
	queryCtx, cancel := withQueryTimeout(ctx)
	defer cancel()

	var found int
	row := s.db.QueryRowContext(queryCtx, existsSQL, opts.Provider, opts.Model, opts.InputVersion, opts.Dimensions)
	err := row.Scan(&found)
	return found == 1, err
}
// HasMessageEmbeddings reports whether any embedding is stored for the given
// provider, model and input version, regardless of dimensions.
//
// provider is trimmed and lower-cased, model and inputVersion are trimmed, and
// a blank inputVersion falls back to EmbeddingInputVersion.
func (s *Store) HasMessageEmbeddings(ctx context.Context, provider, model, inputVersion string) (bool, error) {
	const existsSQL = `
select exists(
select 1
from message_embeddings
where provider = ?
and model = ?
and input_version = ?
)
`
	provider = strings.ToLower(strings.TrimSpace(provider))
	model = strings.TrimSpace(model)
	if inputVersion = strings.TrimSpace(inputVersion); inputVersion == "" {
		inputVersion = EmbeddingInputVersion
	}

	queryCtx, cancel := withQueryTimeout(ctx)
	defer cancel()

	var found int
	err := s.db.QueryRowContext(queryCtx, existsSQL, provider, model, inputVersion).Scan(&found)
	return found == 1, err
}
// CheckMessageFTS probes the message_fts table with a match query for
// messageFTSHealthProbe on a read-only handle. It returns nil when the query
// runs, whether or not a row matches, and the underlying error otherwise.
func (s *Store) CheckMessageFTS(ctx context.Context) error {
	db, release, err := s.openReadOnlyDB()
	if err != nil {
		return err
	}
	if release != nil {
		defer release()
	}

	probeCtx, cancel := withQueryTimeout(ctx)
	defer cancel()

	var matched sql.NullInt64
	row := db.QueryRowContext(
		probeCtx,
		`select rowid from message_fts where message_fts match ? limit 1`,
		messageFTSHealthProbe,
	)
	if err = row.Scan(&matched); err != nil && !errors.Is(err, sql.ErrNoRows) {
		return err
	}
	// Either a row matched or none did; both mean the index answered.
	return nil
}
// searchFallback searches messages with a LIKE substring match on
// normalized_content, newest first, applying the same guild/channel/author
// filters as SearchMessages. The author filter matches the author ID exactly
// or opts.Author anywhere in the raw JSON payload. AuthorName is always
// returned empty. opts.Limit is passed to the statement as-is, and ctx is used
// without adding a timeout.
func (s *Store) searchFallback(ctx context.Context, opts SearchOptions) ([]SearchResult, error) {
	filters := []string{"m.normalized_content like ?"}
	params := []any{"%" + opts.Query + "%"}
	if count := len(opts.GuildIDs); count > 0 {
		filters = append(filters, "m.guild_id in ("+placeholders(count)+")")
		for i := range opts.GuildIDs {
			params = append(params, opts.GuildIDs[i])
		}
	}
	if strings.TrimSpace(opts.Channel) != "" {
		filters = append(filters, "(m.channel_id = ? or c.name like ?)")
		params = append(params, opts.Channel, "%"+opts.Channel+"%")
	}
	if strings.TrimSpace(opts.Author) != "" {
		filters = append(filters, "(m.author_id = ? or m.raw_json like ?)")
		params = append(params, opts.Author, "%"+opts.Author+"%")
	}
	if !opts.IncludeEmpty {
		filters = append(filters, "trim(coalesce(m.normalized_content, '')) <> ''")
	}
	params = append(params, opts.Limit)

	statement := `
select
m.id,
m.guild_id,
m.channel_id,
coalesce(c.name, ''),
coalesce(m.author_id, ''),
'',
case
when trim(coalesce(m.content, '')) <> '' then m.content
else m.normalized_content
end,
m.created_at
from messages m
left join channels c on c.id = m.channel_id
where ` + strings.Join(filters, " and ") + `
order by m.created_at desc
limit ?
`
	rows, err := s.db.QueryContext(ctx, statement, params...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	var results []SearchResult
	for rows.Next() {
		var (
			hit       SearchResult
			createdAt string
		)
		scanErr := rows.Scan(
			&hit.MessageID, &hit.GuildID, &hit.ChannelID, &hit.ChannelName,
			&hit.AuthorID, &hit.AuthorName, &hit.Content, &createdAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		hit.CreatedAt = parseTime(createdAt)
		results = append(results, hit)
	}
	return results, rows.Err()
}
// cosineSimilarity returns dot(query, vector) / (queryNorm * |vector|).
//
// queryNorm is the precomputed Euclidean norm of query, so callers scoring
// many vectors against one query compute it once. An error is returned when
// the two vectors differ in length or when vector has zero norm.
func cosineSimilarity(query []float32, queryNorm float64, vector []float32) (float64, error) {
	if got, want := len(vector), len(query); got != want {
		return 0, fmt.Errorf("dimensions mismatch: got %d want %d", got, want)
	}
	storedNorm := vectorNorm(vector)
	if storedNorm == 0 {
		return 0, errors.New("stored embedding vector is zero")
	}
	dot := 0.0
	for i, q := range query {
		dot += float64(q) * float64(vector[i])
	}
	return dot / (queryNorm * storedNorm), nil
}
// vectorNorm returns the Euclidean (L2) norm of vector, accumulating in
// float64. A nil or empty vector has norm 0.
func vectorNorm(vector []float32) float64 {
	squares := 0.0
	for i := range vector {
		component := float64(vector[i])
		squares += component * component
	}
	return math.Sqrt(squares)
}
// Members lists member rows ordered by their best available display name.
//
// A non-blank query delegates to searchMembers with the arguments unchanged.
// Otherwise limit <= 0 becomes 100 and an empty guildID means all guilds.
func (s *Store) Members(ctx context.Context, guildID, query string, limit int) ([]MemberRow, error) {
	if strings.TrimSpace(query) != "" {
		return s.searchMembers(ctx, guildID, query, limit)
	}
	if limit <= 0 {
		limit = 100
	}

	condition := "1=1"
	var params []any
	if guildID != "" {
		condition += " and guild_id = ?"
		params = append(params, guildID)
	}
	params = append(params, limit)

	rows, err := s.db.QueryContext(ctx, `
select guild_id, user_id, username, coalesce(global_name, ''), coalesce(display_name, ''),
coalesce(nick, ''), coalesce(discriminator, ''), coalesce(avatar, ''),
role_ids_json, bot, coalesce(joined_at, ''), raw_json
from members
where `+condition+`
order by coalesce(nullif(display_name, ''), nullif(nick, ''), nullif(global_name, ''), username), username
limit ?
`, params...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()
	return scanMemberRows(rows)
}
// MemberByID returns every member row for userID, one per guild the user
// appears in, ordered by guild ID and then username.
func (s *Store) MemberByID(ctx context.Context, userID string) ([]MemberRow, error) {
	const memberSQL = `
select guild_id, user_id, username, coalesce(global_name, ''), coalesce(display_name, ''),
coalesce(nick, ''), coalesce(discriminator, ''), coalesce(avatar, ''),
role_ids_json, bot, coalesce(joined_at, ''), raw_json
from members
where user_id = ?
order by guild_id, username
`
	rows, err := s.db.QueryContext(ctx, memberSQL, userID)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()
	return scanMemberRows(rows)
}
// Channels lists channel rows ordered by guild, position and name. An empty
// guildID returns the channels of every guild.
func (s *Store) Channels(ctx context.Context, guildID string) ([]ChannelRow, error) {
	statement := `
select id, guild_id, coalesce(parent_id, ''), kind, name, coalesce(topic, ''), position,
is_nsfw, is_archived, is_locked, is_private_thread, coalesce(thread_parent_id, ''), coalesce(archive_timestamp, '')
from channels
`
	var params []any
	if guildID != "" {
		statement += ` where guild_id = ?`
		params = append(params, guildID)
	}
	statement += ` order by guild_id, position, name`

	rows, err := s.db.QueryContext(ctx, statement, params...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	var channels []ChannelRow
	for rows.Next() {
		var (
			channel ChannelRow
			// The flag columns are read as integers; only the value 1
			// counts as true.
			nsfwFlag, archivedFlag, lockedFlag, privateFlag int
			archivedAt                                      string
		)
		scanErr := rows.Scan(
			&channel.ID, &channel.GuildID, &channel.ParentID, &channel.Kind, &channel.Name,
			&channel.Topic, &channel.Position,
			&nsfwFlag, &archivedFlag, &lockedFlag, &privateFlag,
			&channel.ThreadParentID, &archivedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		channel.IsNSFW = nsfwFlag == 1
		channel.IsArchived = archivedFlag == 1
		channel.IsLocked = lockedFlag == 1
		channel.IsPrivateThread = privateFlag == 1
		channel.ArchiveTimestamp = parseTime(archivedAt)
		channels = append(channels, channel)
	}
	return channels, rows.Err()
}
// GuildChannelCount returns the number of channel rows stored for guildID.
func (s *Store) GuildChannelCount(ctx context.Context, guildID string) (int, error) {
	const countSQL = `
select count(*)
from channels
where guild_id = ?
`
	var total int
	err := s.db.QueryRowContext(ctx, countSQL, guildID).Scan(&total)
	if err != nil {
		return 0, err
	}
	return total, nil
}
// GuildMemberCount returns the number of member rows stored for guildID.
func (s *Store) GuildMemberCount(ctx context.Context, guildID string) (int, error) {
	const countSQL = `
select count(*)
from members
where guild_id = ?
`
	var total int
	err := s.db.QueryRowContext(ctx, countSQL, guildID).Scan(&total)
	if err != nil {
		return 0, err
	}
	return total, nil
}
// IncompleteMessageChannelIDs returns, sorted by ID, the message-bearing
// channels (text, news/announcement and thread kinds) that have neither a
// "channel:<id>:history_complete" nor a "channel:<id>:unavailable" entry in
// sync_state. An empty guildID covers every guild.
func (s *Store) IncompleteMessageChannelIDs(ctx context.Context, guildID string) ([]string, error) {
	var (
		statement strings.Builder
		params    []any
	)
	statement.WriteString(`
select c.id
from channels c
where c.kind in ('text', 'news', 'announcement', 'thread_public', 'thread_private', 'thread_news', 'thread_announcement')
`)
	if guildID != "" {
		statement.WriteString(` and c.guild_id = ?`)
		params = append(params, guildID)
	}
	statement.WriteString(`
and not exists (
select 1
from sync_state s
where s.scope = 'channel:' || c.id || ':history_complete'
)
and not exists (
select 1
from sync_state s
where s.scope = 'channel:' || c.id || ':unavailable'
)
order by c.id
`)

	rows, err := s.db.QueryContext(ctx, statement.String(), params...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	var ids []string
	for rows.Next() {
		var channelID string
		if scanErr := rows.Scan(&channelID); scanErr != nil {
			return nil, scanErr
		}
		ids = append(ids, channelID)
	}
	return ids, rows.Err()
}
// Status gathers summary counters and sync timestamps for the database.
//
// The row counts and the guild ID listing are mandatory: a failure in any of
// them returns a zero Status and the error. The last-sync time, last-tail
// time and default guild name are best-effort lookups whose errors are
// ignored, leaving the corresponding field at its zero value.
func (s *Store) Status(ctx context.Context, dbPath, defaultGuildID string) (Status, error) {
	report := Status{DBPath: dbPath, DefaultGuildID: defaultGuildID}

	// The counters are independent of each other, so the order in which
	// they are collected does not matter.
	counters := []struct {
		statement string
		dest      *int
	}{
		{`select count(*) from guilds`, &report.GuildCount},
		{`select count(*) from channels`, &report.ChannelCount},
		{`select count(*) from messages`, &report.MessageCount},
		{`select count(*) from members`, &report.MemberCount},
		{`select count(*) from embedding_jobs where state = 'pending'`, &report.EmbeddingBacklog},
	}
	for _, counter := range counters {
		if err := s.db.QueryRowContext(ctx, counter.statement).Scan(counter.dest); err != nil {
			return Status{}, err
		}
	}
	threadRow := s.db.QueryRowContext(ctx, `select count(*) from channels where kind like 'thread_%'`)
	if err := threadRow.Scan(&report.ThreadCount); err != nil {
		return Status{}, err
	}

	// Best-effort: a missing row leaves the string empty, which parseTime
	// maps to the zero time.
	var lastSyncAt, lastTailAt string
	_ = s.db.QueryRowContext(ctx, `select updated_at from sync_state where scope = 'sync:last_success'`).Scan(&lastSyncAt)
	report.LastSyncAt = parseTime(lastSyncAt)
	_ = s.db.QueryRowContext(ctx, `select updated_at from sync_state where scope = 'tail:last_event'`).Scan(&lastTailAt)
	report.LastTailEventAt = parseTime(lastTailAt)

	if defaultGuildID != "" {
		// Best-effort: an unknown guild simply has no name.
		_ = s.db.QueryRowContext(ctx, `select name from guilds where id = ?`, defaultGuildID).Scan(&report.DefaultGuildName)
	}

	rows, err := s.db.QueryContext(ctx, `select id from guilds order by id`)
	if err != nil {
		return Status{}, err
	}
	defer func() { _ = rows.Close() }()
	for rows.Next() {
		var id string
		if scanErr := rows.Scan(&id); scanErr != nil {
			return Status{}, scanErr
		}
		report.AccessibleGuildIDs = append(report.AccessibleGuildIDs, id)
	}
	return report, rows.Err()
}
// ReadOnlyQuery runs query on the handle returned by openReadOnlyDB and
// returns the column names and stringified rows.
//
// The query is trimmed first. An empty query, or one rejected by
// IsReadOnlySQL, fails before any database access.
func (s *Store) ReadOnlyQuery(ctx context.Context, query string) ([]string, [][]string, error) {
	statement := strings.TrimSpace(query)
	switch {
	case statement == "":
		return nil, nil, errors.New("empty query")
	case !IsReadOnlySQL(statement):
		return nil, nil, errors.New("only read-only sql is allowed")
	}

	db, release, err := s.openReadOnlyDB()
	if err != nil {
		return nil, nil, err
	}
	if release != nil {
		defer release()
	}
	return queryRows(ctx, db, statement)
}
// Query runs the trimmed query on the store's main handle, without any
// read-only check, and returns the column names and stringified rows. An empty
// query is rejected.
func (s *Store) Query(ctx context.Context, query string) ([]string, [][]string, error) {
	statement := strings.TrimSpace(query)
	if statement == "" {
		return nil, nil, errors.New("empty query")
	}
	return queryRows(ctx, s.db, statement)
}
// Exec runs the trimmed statement on the store's main handle under the
// default query timeout and returns the number of rows affected. An empty
// statement is rejected.
func (s *Store) Exec(ctx context.Context, query string) (int64, error) {
	statement := strings.TrimSpace(query)
	if statement == "" {
		return 0, errors.New("empty query")
	}

	execCtx, cancel := withQueryTimeout(ctx)
	defer cancel()

	outcome, err := s.db.ExecContext(execCtx, statement)
	if err != nil {
		return 0, err
	}
	changed, err := outcome.RowsAffected()
	if err != nil {
		return 0, err
	}
	return changed, nil
}
// queryRows executes query on db under the default query timeout and returns
// the column names plus every row with each value rendered by stringify.
//
// It fails when the statement yields no columns or more than queryRowLimit
// rows. An iteration error from the driver is returned together with the
// columns and the rows collected so far.
func queryRows(ctx context.Context, db *sql.DB, query string) ([]string, [][]string, error) {
	queryCtx, cancel := withQueryTimeout(ctx)
	defer cancel()

	rows, err := db.QueryContext(queryCtx, query)
	if err != nil {
		return nil, nil, err
	}
	defer func() { _ = rows.Close() }()

	columns, err := rows.Columns()
	if err != nil {
		return nil, nil, err
	}
	width := len(columns)
	if width == 0 {
		return nil, nil, errors.New("query returned no columns")
	}

	var records [][]string
	for rows.Next() {
		// Checked before scanning, so the error fires on the first row
		// beyond the limit.
		if len(records) >= queryRowLimit {
			return nil, nil, fmt.Errorf("query returned more than %d rows", queryRowLimit)
		}
		cells := make([]any, width)
		targets := make([]any, width)
		for i := range cells {
			targets[i] = &cells[i]
		}
		if scanErr := rows.Scan(targets...); scanErr != nil {
			return nil, nil, scanErr
		}
		record := make([]string, width)
		for i := range cells {
			record[i] = stringify(cells[i])
		}
		records = append(records, record)
	}
	return columns, records, rows.Err()
}
// openReadOnlyDB returns a database handle for read-only work together with
// an optional cleanup function.
//
// When the store has no path, the main handle is returned with a nil cleanup.
// Otherwise the file must exist, and a separate single-connection handle is
// opened with mode=ro and query_only set; the returned function closes it.
func (s *Store) openReadOnlyDB() (*sql.DB, func(), error) {
	if strings.TrimSpace(s.path) == "" {
		return s.db, nil, nil
	}
	if _, statErr := os.Stat(s.path); statErr != nil {
		return nil, nil, statErr
	}

	readOnly, err := sql.Open("sqlite", fmt.Sprintf(
		"file:%s?mode=ro&_pragma=query_only(1)&_pragma=busy_timeout(5000)&_pragma=temp_store(MEMORY)&_pragma=mmap_size(268435456)",
		s.path,
	))
	if err != nil {
		return nil, nil, err
	}
	readOnly.SetMaxOpenConns(1)
	readOnly.SetMaxIdleConns(1)

	closeHandle := func() { _ = readOnly.Close() }
	return readOnly, closeHandle, nil
}
// withQueryTimeout derives a context for a single query. A caller-supplied
// deadline is respected as-is (only a cancel func is added); otherwise the
// default queryTimeout is applied. The returned cancel func must be called.
func withQueryTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	_, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		return context.WithTimeout(ctx, queryTimeout)
	}
	return context.WithCancel(ctx)
}
func searchCandidateLimit(limit int) int {
if limit <= 0 {
limit = 20
}
candidates := limit * searchCandidateMultiple
if candidates < searchCandidateFloor {
return searchCandidateFloor
}
if candidates > searchCandidateCap {
return searchCandidateCap
}
return candidates
}
// IsReadOnlySQL reports whether the first keyword of query, after leading
// whitespace and comments, is select, with, explain or pragma.
//
// Only that first keyword is inspected.
// NOTE(review): statements that begin with one of these keywords but still
// modify data would also pass; this check should not be the only guard —
// confirm callers also use a read-only connection.
func IsReadOnlySQL(query string) bool {
	keyword := leadingSQLKeyword(query)
	return keyword == "select" ||
		keyword == "with" ||
		keyword == "explain" ||
		keyword == "pragma"
}
// leadingSQLKeyword returns the lower-cased first keyword of query after
// skipping leading whitespace, "--" line comments and "/* */" block comments.
//
// It returns "" when the query is empty, consists only of comments, or opens
// a block comment that is never closed. The keyword ends at the first
// character that cannot be part of an SQL word, so "select*from t" and
// "SELECT(1)" both yield "select". If the first token does not start with a
// word character (for example "(select"), the whole whitespace-delimited
// token is returned lower-cased.
func leadingSQLKeyword(query string) string {
	// isWordRune reports whether r may appear inside an unquoted SQL word.
	isWordRune := func(r rune) bool {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9':
			return true
		case r == '_', r == '$', r >= 0x80:
			return true
		}
		return false
	}

	rest := strings.TrimSpace(query)
	for rest != "" {
		switch {
		case strings.HasPrefix(rest, "--"):
			newline := strings.IndexByte(rest, '\n')
			if newline < 0 {
				return ""
			}
			rest = strings.TrimSpace(rest[newline+1:])
		case strings.HasPrefix(rest, "/*"):
			// Look for the terminator only after the opening "/*".
			// Otherwise "/*/" would count as a complete comment and
			// "/*/ select */ drop ..." would be reported as "select".
			closing := strings.Index(rest[2:], "*/")
			if closing < 0 {
				return ""
			}
			rest = strings.TrimSpace(rest[2+closing+2:])
		default:
			fields := strings.Fields(rest)
			if len(fields) == 0 {
				return ""
			}
			word := fields[0]
			// Cut trailing punctuation glued to the keyword; keep the
			// token untouched when it does not start with a word rune.
			if cut := strings.IndexFunc(word, func(r rune) bool { return !isWordRune(r) }); cut > 0 {
				word = word[:cut]
			}
			return strings.ToLower(word)
		}
	}
	return ""
}
// placeholders returns n "?" bind markers separated by ", ", suitable for an
// SQL "in (...)" list. It returns "" when n <= 0.
func placeholders(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.Repeat("?, ", n-1) + "?"
}
// stringify renders one scanned column value as text: nil becomes "",
// strings and byte slices are returned verbatim, time.Time is formatted with
// timeLayout, and anything else goes through fmt.Sprint.
func stringify(value any) string {
	if value == nil {
		return ""
	}
	switch typed := value.(type) {
	case string:
		return typed
	case []byte:
		return string(typed)
	case time.Time:
		return typed.Format(timeLayout)
	}
	return fmt.Sprint(value)
}
// normalizeFTSQuery turns free-form user input into a match expression of
// quoted terms: the input is split on whitespace, every double quote inside a
// term is replaced by a space, and each term is wrapped in double quotes.
// Blank input yields "".
func normalizeFTSQuery(raw string) string {
	terms := strings.Fields(raw)
	if len(terms) == 0 {
		return ""
	}
	var quoted strings.Builder
	for i, term := range terms {
		if i > 0 {
			quoted.WriteByte(' ')
		}
		quoted.WriteByte('"')
		quoted.WriteString(strings.ReplaceAll(term, `"`, " "))
		quoted.WriteByte('"')
	}
	return quoted.String()
}
// parseTime parses a stored timestamp, trying timeLayout first and
// time.RFC3339 second. An empty or unparseable value yields the zero
// time.Time; no error is reported.
func parseTime(value string) time.Time {
	if value == "" {
		return time.Time{}
	}
	for _, layout := range [...]string{timeLayout, time.RFC3339} {
		if parsed, err := time.Parse(layout, value); err == nil {
			return parsed
		}
	}
	return time.Time{}
}