From 40c787c54a73b98e923cf696bde72b8b335d0ede Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Fri, 8 May 2026 09:58:34 +0100 Subject: [PATCH] refactor: consume crawlkit embedding primitives --- CHANGELOG.md | 1 + go.mod | 2 +- go.sum | 4 +- internal/cli/admin_commands.go | 6 +- internal/cli/cli.go | 13 +- internal/cli/query_commands.go | 4 +- internal/embed/ollama.go | 91 ------- internal/embed/openai_compatible.go | 82 ------ internal/embed/provider.go | 310 ---------------------- internal/embed/provider_test.go | 387 ---------------------------- internal/store/embeddings.go | 28 +- internal/store/query.go | 74 ++---- internal/store/store_write_test.go | 3 +- 13 files changed, 59 insertions(+), 946 deletions(-) delete mode 100644 internal/embed/ollama.go delete mode 100644 internal/embed/openai_compatible.go delete mode 100644 internal/embed/provider.go delete mode 100644 internal/embed/provider_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index f51bafa..c6525c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ ### Maintenance - Migrated runtime paths, SQLite opening, archive mirror/export/import helpers, output/status wiring, and TUI plumbing onto the shared `crawlkit` infrastructure. +- Moved reusable embedding providers and vector helpers onto `crawlkit` while keeping Discrawl-owned storage, FTS, queueing, and privacy filters local. - Updated crawlkit through `v0.4.1`, switched imports to `github.com/openclaw/crawlkit`, and added CI smoke coverage for the crawlkit control surface and merge behavior. - Added CodeQL, verified secret scanning, protected automation owners, stale issue automation, `.editorconfig`, and `.gitattributes`. - Added release workflow automation that dispatches the Homebrew tap formula update after GoReleaser publishes a tag. diff --git a/go.mod b/go.mod index ea54c2b..7d4ac4f 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect - github.com/openclaw/crawlkit v0.4.2 + github.com/openclaw/crawlkit v0.5.0 github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect diff --git a/go.sum b/go.sum index 2f01c5a..114e7c2 100644 --- a/go.sum +++ b/go.sum @@ -63,8 +63,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/openclaw/crawlkit v0.4.2 h1:Lzzkd2/xSkQk+7KyboMEw+ZS2wmlYvDFLwAB2Z/FwBs= -github.com/openclaw/crawlkit v0.4.2/go.mod h1:/AI8o/DeRqXPZJPHq/9mGUjNzLPskm/wTjikRPxEdHY= +github.com/openclaw/crawlkit v0.5.0 h1:sVqIbQ5v6LiOf+NXcVj93UhfoaJqMbBlrd1lU6uhO9M= +github.com/openclaw/crawlkit v0.5.0/go.mod h1:/AI8o/DeRqXPZJPHq/9mGUjNzLPskm/wTjikRPxEdHY= github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc= github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/internal/cli/admin_commands.go b/internal/cli/admin_commands.go index e6b0871..8b31975 100644 --- a/internal/cli/admin_commands.go +++ b/internal/cli/admin_commands.go @@ -13,10 +13,10 @@ import ( "syscall" "time" + "github.com/openclaw/crawlkit/embed" "github.com/openclaw/discrawl/internal/config" "github.com/openclaw/discrawl/internal/discord" "github.com/openclaw/discrawl/internal/discorddesktop" - "github.com/openclaw/discrawl/internal/embed" "github.com/openclaw/discrawl/internal/share" "github.com/openclaw/discrawl/internal/store" "github.com/openclaw/discrawl/internal/syncer" @@ -374,7 +374,7 @@ func (r *runtime) runEmbed(args []string) error { providerFactory := r.newEmbed if providerFactory == nil { providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) { - return embed.NewProvider(cfg) + return embed.NewProvider(crawlkitEmbeddingConfig(cfg)) } } provider, err := providerFactory(r.cfg.Search.Embeddings) @@ -435,7 +435,7 @@ func (r *runtime) runDoctor(args []string) error { report["share_stale_after"] = cfg.Share.StaleAfter } if cfg.Search.Embeddings.Enabled { - check := embed.CheckProvider(r.ctx, cfg.Search.Embeddings) + check := embed.CheckProvider(r.ctx, crawlkitEmbeddingConfig(cfg.Search.Embeddings)) report["embeddings"] = check.Status report["embeddings_provider"] = check.Provider report["embeddings_model"] = check.Model diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 66ac63b..5bd9228 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -11,9 +11,9 @@ import ( "time" "github.com/bwmarrin/discordgo" + "github.com/openclaw/crawlkit/embed" "github.com/openclaw/discrawl/internal/config" "github.com/openclaw/discrawl/internal/discord" - "github.com/openclaw/discrawl/internal/embed" "github.com/openclaw/discrawl/internal/share" "github.com/openclaw/discrawl/internal/store" "github.com/openclaw/discrawl/internal/syncer" @@ -118,6 +118,17 @@ type runtime struct { now func() time.Time } +func crawlkitEmbeddingConfig(cfg config.EmbeddingsConfig) embed.Config { + return embed.Config{ + Provider: cfg.Provider, + Model: cfg.Model, + BaseURL: cfg.BaseURL, + APIKeyEnv: cfg.APIKeyEnv, + RequestTimeout: cfg.RequestTimeout, + MaxInputChars: cfg.MaxInputChars, + } +} + type discordClient interface { syncer.Client Close() error diff --git a/internal/cli/query_commands.go b/internal/cli/query_commands.go index 1fc698f..54dca46 100644 --- a/internal/cli/query_commands.go +++ b/internal/cli/query_commands.go @@ -9,8 +9,8 @@ import ( "os" "strings" + "github.com/openclaw/crawlkit/embed" "github.com/openclaw/discrawl/internal/config" - "github.com/openclaw/discrawl/internal/embed" "github.com/openclaw/discrawl/internal/store" ) @@ -112,7 +112,7 @@ func (r *runtime) semanticSearchOptions(opts store.SearchOptions) (store.Semanti providerFactory := r.newEmbed if providerFactory == nil { providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) { - return embed.NewProvider(cfg) + return embed.NewProvider(crawlkitEmbeddingConfig(cfg)) } } provider, err := providerFactory(r.cfg.Search.Embeddings) diff --git a/internal/embed/ollama.go b/internal/embed/ollama.go deleted file mode 100644 index b5daa15..0000000 --- a/internal/embed/ollama.go +++ /dev/null @@ -1,91 +0,0 @@ -package embed - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" -) - -type ollamaProvider struct { - client *http.Client - baseURL string - model string - maxInputChars int -} - -type ollamaEmbedRequest struct { - Model string `json:"model"` - Input []string `json:"input"` -} - -type ollamaEmbedResponse struct { - Model string `json:"model"` - Embeddings [][]float32 `json:"embeddings"` -} - -func newOllamaProvider(settings providerSettings) Provider { - return &ollamaProvider{ - client: settings.HTTPClient, - baseURL: settings.BaseURL, - model: settings.Model, - maxInputChars: settings.MaxInputChars, - } -} - -func (p *ollamaProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) { - if len(inputs) == 0 { - return EmbeddingBatch{Model: p.model}, nil - } - payload := ollamaEmbedRequest{ - Model: p.model, - Input: trimInputs(inputs, p.maxInputChars), - } - var response ollamaEmbedResponse - if err := postJSON(ctx, p.client, p.baseURL+"/api/embed", "", payload, &response); err != nil { - return EmbeddingBatch{}, err - } - if len(response.Embeddings) != len(inputs) { - return EmbeddingBatch{}, fmt.Errorf("ollama embedding response returned %d vectors for %d inputs", len(response.Embeddings), len(inputs)) - } - dimensions, err := inferDimensions(response.Embeddings) - if err != nil { - return EmbeddingBatch{}, err - } - model := response.Model - if model == "" { - model = p.model - } - return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: response.Embeddings}, nil -} - -func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, payload any, target any) error { - body, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("marshal embedding request: %w", err) - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("build embedding request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - if apiKey != "" { - req.Header.Set("Authorization", "Bearer "+apiKey) - } - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("embedding request failed: %w", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)} - } - if err := json.NewDecoder(resp.Body).Decode(target); err != nil { - return fmt.Errorf("decode embedding response: %w", err) - } - return nil -} diff --git a/internal/embed/openai_compatible.go b/internal/embed/openai_compatible.go deleted file mode 100644 index 4c65a5d..0000000 --- a/internal/embed/openai_compatible.go +++ /dev/null @@ -1,82 +0,0 @@ -package embed - -import ( - "context" - "fmt" - "net/http" -) - -type openAICompatibleProvider struct { - client *http.Client - baseURL string - apiKey string - model string - maxInputChars int -} - -type openAIEmbeddingRequest struct { - Model string `json:"model"` - Input []string `json:"input"` -} - -type openAIEmbeddingResponse struct { - Model string `json:"model"` - Data []openAIEmbeddingItem `json:"data"` -} - -type openAIEmbeddingItem struct { - Index *int `json:"index"` - Embedding []float32 `json:"embedding"` -} - -func newOpenAICompatibleProvider(settings providerSettings) Provider { - return &openAICompatibleProvider{ - client: settings.HTTPClient, - baseURL: settings.BaseURL, - apiKey: settings.APIKey, - model: settings.Model, - maxInputChars: settings.MaxInputChars, - } -} - -func (p *openAICompatibleProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) { - if len(inputs) == 0 { - return EmbeddingBatch{Model: p.model}, nil - } - payload := openAIEmbeddingRequest{ - Model: p.model, - Input: trimInputs(inputs, p.maxInputChars), - } - var response openAIEmbeddingResponse - if err := postJSON(ctx, p.client, p.baseURL+"/embeddings", p.apiKey, payload, &response); err != nil { - return EmbeddingBatch{}, err - } - if len(response.Data) != len(inputs) { - return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response returned %d vectors for %d inputs", len(response.Data), len(inputs)) - } - vectors := make([][]float32, len(inputs)) - seen := make([]bool, len(inputs)) - for position, item := range response.Data { - index := position - if item.Index != nil { - index = *item.Index - } - if index < 0 || index >= len(inputs) { - return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response index %d out of range", index) - } - if seen[index] { - return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response duplicated index %d", index) - } - seen[index] = true - vectors[index] = item.Embedding - } - dimensions, err := inferDimensions(vectors) - if err != nil { - return EmbeddingBatch{}, err - } - model := response.Model - if model == "" { - model = p.model - } - return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: vectors}, nil -} diff --git a/internal/embed/provider.go b/internal/embed/provider.go deleted file mode 100644 index 4396ae9..0000000 --- a/internal/embed/provider.go +++ /dev/null @@ -1,310 +0,0 @@ -package embed - -import ( - "context" - "errors" - "fmt" - "net" - "net/http" - "net/url" - "os" - "strings" - "time" - - "github.com/openclaw/discrawl/internal/config" -) - -const ( - ProviderOpenAI = "openai" - ProviderOllama = "ollama" - ProviderLlamaCpp = "llamacpp" - ProviderOpenAICompatible = "openai_compatible" - DefaultOpenAIBaseURL = "https://api.openai.com/v1" - DefaultOllamaBaseURL = "http://127.0.0.1:11434" - DefaultLlamaCppBaseURL = "http://127.0.0.1:8080/v1" - DefaultOpenAIModel = "text-embedding-3-small" - DefaultLocalEmbeddingModel = "nomic-embed-text" - DefaultBatchSize = 64 - DefaultMaxInputChars = 12000 - DefaultRequestTimeout = 2 * time.Minute - DefaultProbeTimeout = 2 * time.Second -) - -// Provider is the narrow embedding surface used by later queue/search work. -type Provider interface { - Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) -} - -type EmbeddingBatch struct { - Model string - Dimensions int - Vectors [][]float32 -} - -type HTTPError struct { - StatusCode int - Body string -} - -func (e *HTTPError) Error() string { - return fmt.Sprintf("embedding request failed with HTTP %d: %s", e.StatusCode, e.Body) -} - -func IsRateLimitError(err error) bool { - var httpErr *HTTPError - return errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusTooManyRequests -} - -type CheckResult struct { - Provider string - Model string - BaseURL string - Status string - Warning string - Probed bool -} - -type Option func(*providerOptions) - -type providerOptions struct { - httpClient *http.Client - timeoutOverride time.Duration -} - -type providerSettings struct { - Name string - Model string - BaseURL string - APIKey string - MaxInputChars int - Timeout time.Duration - HTTPClient *http.Client -} - -func WithHTTPClient(client *http.Client) Option { - return func(opts *providerOptions) { - opts.httpClient = client - } -} - -func WithRequestTimeout(timeout time.Duration) Option { - return func(opts *providerOptions) { - opts.timeoutOverride = timeout - } -} - -func NewProvider(cfg config.EmbeddingsConfig, opts ...Option) (Provider, error) { - settings, err := resolveProviderConfig(cfg, true, opts...) - if err != nil { - return nil, err - } - return newProvider(settings) -} - -func CheckProvider(ctx context.Context, cfg config.EmbeddingsConfig) CheckResult { - settings, err := resolveProviderConfig(cfg, true, WithRequestTimeout(DefaultProbeTimeout)) - if err != nil { - return CheckResult{ - Provider: normalizedProviderName(cfg.Provider), - Model: strings.TrimSpace(cfg.Model), - BaseURL: strings.TrimSpace(cfg.BaseURL), - Status: "warning", - Warning: err.Error(), - } - } - result := CheckResult{ - Provider: settings.Name, - Model: settings.Model, - BaseURL: settings.BaseURL, - Status: "ok", - } - if !shouldProbe(settings) { - return result - } - provider, err := newProvider(settings) - if err != nil { - result.Status = "warning" - result.Warning = err.Error() - return result - } - probeCtx, cancel := context.WithTimeout(ctx, DefaultProbeTimeout) - defer cancel() - if _, err := provider.Embed(probeCtx, []string{"discrawl probe"}); err != nil { - result.Status = "warning" - result.Warning = err.Error() - return result - } - result.Probed = true - return result -} - -func resolveProviderConfig(cfg config.EmbeddingsConfig, validateAPIKey bool, opts ...Option) (providerSettings, error) { - options := providerOptions{} - for _, opt := range opts { - opt(&options) - } - name := normalizedProviderName(cfg.Provider) - if name == "" { - name = ProviderOpenAI - } - model := strings.TrimSpace(cfg.Model) - if model == "" { - model = defaultModel(name) - } - baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/") - if baseURL == "" { - switch name { - case ProviderOpenAI: - baseURL = DefaultOpenAIBaseURL - case ProviderOllama: - baseURL = DefaultOllamaBaseURL - case ProviderLlamaCpp: - baseURL = DefaultLlamaCppBaseURL - case ProviderOpenAICompatible: - return providerSettings{}, fmt.Errorf("embedding provider %q requires base_url", name) - } - } - timeout := DefaultRequestTimeout - if strings.TrimSpace(cfg.RequestTimeout) != "" { - parsed, err := time.ParseDuration(cfg.RequestTimeout) - if err != nil { - return providerSettings{}, fmt.Errorf("parse embeddings request_timeout: %w", err) - } - if parsed <= 0 { - return providerSettings{}, errors.New("embeddings request_timeout must be positive") - } - timeout = parsed - } - if options.timeoutOverride > 0 && options.timeoutOverride < timeout { - timeout = options.timeoutOverride - } - maxInputChars := cfg.MaxInputChars - if maxInputChars <= 0 { - maxInputChars = DefaultMaxInputChars - } - switch name { - case ProviderOpenAI, ProviderOllama, ProviderLlamaCpp, ProviderOpenAICompatible: - default: - return providerSettings{}, fmt.Errorf("unsupported embedding provider %q", name) - } - apiKey, err := resolveAPIKey(name, cfg.APIKeyEnv, validateAPIKey) - if err != nil { - return providerSettings{}, err - } - client := options.httpClient - if client == nil { - client = &http.Client{Timeout: timeout} - } - if _, err := url.ParseRequestURI(baseURL); err != nil { - return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: %w", baseURL, err) - } - return providerSettings{ - Name: name, - Model: model, - BaseURL: baseURL, - APIKey: apiKey, - MaxInputChars: maxInputChars, - Timeout: timeout, - HTTPClient: client, - }, nil -} - -func newProvider(settings providerSettings) (Provider, error) { - switch settings.Name { - case ProviderOllama: - return newOllamaProvider(settings), nil - case ProviderOpenAI, ProviderLlamaCpp, ProviderOpenAICompatible: - return newOpenAICompatibleProvider(settings), nil - default: - return nil, fmt.Errorf("unsupported embedding provider %q", settings.Name) - } -} - -func resolveAPIKey(provider, apiKeyEnv string, validate bool) (string, error) { - envName := strings.TrimSpace(apiKeyEnv) - required := provider == ProviderOpenAI - if envName == "" { - if required { - envName = "OPENAI_API_KEY" - } else { - return "", nil - } - } - value := strings.TrimSpace(os.Getenv(envName)) - if value == "" { - if required || validate { - return "", fmt.Errorf("embedding provider %q requires API key env %s", provider, envName) - } - return "", nil - } - return value, nil -} - -func normalizedProviderName(provider string) string { - return strings.ToLower(strings.TrimSpace(provider)) -} - -func defaultModel(provider string) string { - switch provider { - case ProviderOllama, ProviderLlamaCpp: - return DefaultLocalEmbeddingModel - default: - return DefaultOpenAIModel - } -} - -func shouldProbe(settings providerSettings) bool { - switch settings.Name { - case ProviderOllama, ProviderLlamaCpp: - return true - case ProviderOpenAICompatible: - return isLoopbackBaseURL(settings.BaseURL) - default: - return false - } -} - -func isLoopbackBaseURL(rawURL string) bool { - parsed, err := url.Parse(rawURL) - if err != nil { - return false - } - host := parsed.Hostname() - if host == "localhost" { - return true - } - ip := net.ParseIP(host) - return ip != nil && ip.IsLoopback() -} - -func trimInputs(inputs []string, maxChars int) []string { - if maxChars <= 0 { - maxChars = DefaultMaxInputChars - } - out := make([]string, len(inputs)) - for i, input := range inputs { - runes := []rune(input) - if len(runes) > maxChars { - runes = runes[:maxChars] - } - out[i] = string(runes) - } - return out -} - -func inferDimensions(vectors [][]float32) (int, error) { - dimensions := 0 - for _, vector := range vectors { - if len(vector) == 0 { - return 0, errors.New("embedding response contained an empty vector") - } - if dimensions == 0 { - dimensions = len(vector) - continue - } - if len(vector) != dimensions { - return 0, fmt.Errorf("embedding response dimensions mismatch: got %d want %d", len(vector), dimensions) - } - } - return dimensions, nil -} diff --git a/internal/embed/provider_test.go b/internal/embed/provider_test.go deleted file mode 100644 index e22c9ee..0000000 --- a/internal/embed/provider_test.go +++ /dev/null @@ -1,387 +0,0 @@ -package embed - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/openclaw/discrawl/internal/config" -) - -func TestOllamaProviderEmbeds(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/api/embed", r.URL.Path) - assert.Equal(t, http.MethodPost, r.Method) - var req ollamaEmbedRequest - assert.NoError(t, json.NewDecoder(r.Body).Decode(&req)) - assert.Equal(t, "nomic-embed-text", req.Model) - assert.Equal(t, []string{"abcd", "xy"}, req.Input) - _, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`)) - })) - defer server.Close() - - provider, err := NewProvider(config.EmbeddingsConfig{ - Provider: ProviderOllama, - Model: "nomic-embed-text", - BaseURL: server.URL, - MaxInputChars: 4, - RequestTimeout: "5s", - }) - require.NoError(t, err) - - batch, err := provider.Embed(context.Background(), []string{"abcdef", "xy"}) - require.NoError(t, err) - require.Equal(t, "nomic-embed-text", batch.Model) - require.Equal(t, 3, batch.Dimensions) - require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, batch.Vectors) -} - -func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/embeddings", r.URL.Path) - assert.Equal(t, "Bearer secret", r.Header.Get("Authorization")) - var req openAIEmbeddingRequest - assert.NoError(t, json.NewDecoder(r.Body).Decode(&req)) - assert.Equal(t, "local-model", req.Model) - assert.Equal(t, []string{"one", "two"}, req.Input) - _, _ = w.Write([]byte(`{ - "model":"local-model", - "data":[ - {"index":1,"embedding":[3,4]}, - {"index":0,"embedding":[1,2]} - ] - }`)) - })) - defer server.Close() - t.Setenv("DISCRAWL_EMBED_KEY", "secret") - - provider, err := NewProvider(config.EmbeddingsConfig{ - Provider: ProviderOpenAICompatible, - Model: "local-model", - BaseURL: server.URL, - APIKeyEnv: "DISCRAWL_EMBED_KEY", - RequestTimeout: "5s", - }) - require.NoError(t, err) - - batch, err := provider.Embed(context.Background(), []string{"one", "two"}) - require.NoError(t, err) - require.Equal(t, "local-model", batch.Model) - require.Equal(t, 2, batch.Dimensions) - require.Equal(t, [][]float32{{1, 2}, {3, 4}}, batch.Vectors) -} - -func TestProviderFactoryDefaultsAndValidation(t *testing.T) { - t.Setenv("OPENAI_API_KEY", "openai-secret") - - openAI, err := resolveProviderConfig(config.EmbeddingsConfig{ - Provider: ProviderOpenAI, - RequestTimeout: "5s", - }, true) - require.NoError(t, err) - require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL) - require.Equal(t, DefaultOpenAIModel, openAI.Model) - require.Equal(t, "openai-secret", openAI.APIKey) - - ollama, err := resolveProviderConfig(config.EmbeddingsConfig{ - Provider: ProviderOllama, - RequestTimeout: "5s", - }, true) - require.NoError(t, err) - require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL) - require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model) - - llamaCpp, err := resolveProviderConfig(config.EmbeddingsConfig{ - Provider: ProviderLlamaCpp, - RequestTimeout: "5s", - }, true) - require.NoError(t, err) - require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL) - - _, err = resolveProviderConfig(config.EmbeddingsConfig{ - Provider: ProviderOpenAICompatible, - RequestTimeout: "5s", - }, true) - require.ErrorContains(t, err, "requires base_url") -} - -func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) { - t.Setenv("OPENAI_API_KEY", "") - - _, err := NewProvider(config.EmbeddingsConfig{ - Provider: ProviderOpenAI, - RequestTimeout: "5s", - }) - require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY") -} - -func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) { - t.Setenv("MISSING_EMBED_KEY", "") - - _, err := NewProvider(config.EmbeddingsConfig{ - Provider: "bogus", - APIKeyEnv: "MISSING_EMBED_KEY", - RequestTimeout: "5s", - }) - require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"") -} - -func TestCheckProviderProbesLocalProvider(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/api/embed", r.URL.Path) - _, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`)) - })) - defer server.Close() - - result := CheckProvider(context.Background(), config.EmbeddingsConfig{ - Provider: ProviderOllama, - Model: "nomic-embed-text", - BaseURL: server.URL, - RequestTimeout: "5s", - }) - require.Equal(t, "ok", result.Status) - require.True(t, result.Probed) - require.Empty(t, result.Warning) - require.Equal(t, server.URL, result.BaseURL) -} - -func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "not ready", http.StatusServiceUnavailable) - })) - defer server.Close() - - result := CheckProvider(context.Background(), config.EmbeddingsConfig{ - Provider: ProviderOllama, - Model: "nomic-embed-text", - BaseURL: server.URL, - RequestTimeout: "5s", - }) - require.Equal(t, "warning", result.Status) - require.Contains(t, result.Warning, "HTTP 503") - require.False(t, result.Probed) -} - -func TestProviderExposesRateLimitErrors(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "rate limited", http.StatusTooManyRequests) - })) - defer server.Close() - - provider, err := NewProvider(config.EmbeddingsConfig{ - Provider: ProviderOpenAICompatible, - Model: "local-model", - BaseURL: server.URL, - RequestTimeout: "5s", - }) - require.NoError(t, err) - - _, err = provider.Embed(context.Background(), []string{"one"}) - require.ErrorContains(t, err, "HTTP 429") - require.True(t, IsRateLimitError(err)) -} - -func TestProviderRejectsInvalidResponses(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`)) - })) - defer server.Close() - - provider, err := NewProvider(config.EmbeddingsConfig{ - Provider: ProviderOpenAICompatible, - Model: "local-model", - BaseURL: server.URL, - RequestTimeout: "5s", - }) - require.NoError(t, err) - - _, err = provider.Embed(context.Background(), []string{"one", "two"}) - require.ErrorContains(t, err, "dimensions mismatch") -} - -func TestEmbeddingProvidersHandleEmptyInputsAndIndexErrors(t *testing.T) { - t.Parallel() - - settings := providerSettings{ - Name: ProviderOllama, - Model: "model", - BaseURL: "http://127.0.0.1:1", - MaxInputChars: 10, - HTTPClient: http.DefaultClient, - } - ollama := newOllamaProvider(settings) - batch, err := ollama.Embed(context.Background(), nil) - require.NoError(t, err) - require.Equal(t, "model", batch.Model) - - settings.Name = ProviderOpenAICompatible - openai := newOpenAICompatibleProvider(settings) - batch, err = openai.Embed(context.Background(), nil) - require.NoError(t, err) - require.Equal(t, "model", batch.Model) - - tests := []struct { - name string - body string - inputs []string - want string - }{ - { - name: "count", - body: `{"data":[]}`, - inputs: []string{"one"}, - want: "returned 0 vectors for 1 inputs", - }, - { - name: "range", - body: `{"data":[{"index":2,"embedding":[1]}]}`, - inputs: []string{"one"}, - want: "index 2 out of range", - }, - { - name: "duplicate", - body: `{"data":[{"index":0,"embedding":[1]},{"index":0,"embedding":[2]}]}`, - inputs: []string{"one", "two"}, - want: "duplicated index 0", - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte(tc.body)) - })) - defer server.Close() - provider, err := NewProvider(config.EmbeddingsConfig{ - Provider: ProviderOpenAICompatible, - Model: "model", - BaseURL: server.URL, - RequestTimeout: "5s", - }) - require.NoError(t, err) - _, err = provider.Embed(context.Background(), tc.inputs) - require.ErrorContains(t, err, tc.want) - }) - } -} - -func TestProviderOptionsAndProbeDecisions(t *testing.T) { - t.Parallel() - - client := &http.Client{Timeout: time.Second} - settings, err := resolveProviderConfig(config.EmbeddingsConfig{ - Provider: ProviderOllama, - BaseURL: "http://127.0.0.1:11434/", - RequestTimeout: "30s", - }, true, WithHTTPClient(client), WithRequestTimeout(50*time.Millisecond)) - require.NoError(t, err) - require.Same(t, client, settings.HTTPClient) - require.Equal(t, 50*time.Millisecond, settings.Timeout) - require.Equal(t, "http://127.0.0.1:11434", settings.BaseURL) - require.True(t, shouldProbe(settings)) - - require.True(t, isLoopbackBaseURL("http://localhost:8080/v1")) - require.True(t, isLoopbackBaseURL("http://[::1]:8080/v1")) - require.False(t, isLoopbackBaseURL("https://api.example.com/v1")) - require.False(t, isLoopbackBaseURL("://bad")) - require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAI})) - require.True(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "http://localhost:8080/v1"})) - require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "https://api.example.com/v1"})) -} - -func TestProviderValidationEdges(t *testing.T) { - t.Parallel() - - _, err := resolveProviderConfig(config.EmbeddingsConfig{ - Provider: ProviderOllama, - RequestTimeout: "not-a-duration", - }, true) - require.ErrorContains(t, err, "parse embeddings request_timeout") - - _, err = resolveProviderConfig(config.EmbeddingsConfig{ - Provider: ProviderOllama, - RequestTimeout: "0s", - }, true) - require.ErrorContains(t, err, "must be positive") - - _, err = resolveProviderConfig(config.EmbeddingsConfig{ - Provider: ProviderOllama, - BaseURL: "://bad", - }, true) - require.ErrorContains(t, err, "invalid embeddings base_url") - - key, err := resolveAPIKey(ProviderOpenAICompatible, "MISSING_EMBED_KEY", false) - require.NoError(t, err) - require.Empty(t, key) - - _, err = newProvider(providerSettings{Name: "bogus"}) - require.ErrorContains(t, err, "unsupported embedding provider") - - require.Equal(t, []string{"abc"}, trimInputs([]string{"abc"}, 0)) - _, err = inferDimensions([][]float32{{}}) - require.ErrorContains(t, err, "empty vector") -} - -func TestOllamaProviderResponseEdges(t *testing.T) { - t.Parallel() - - countServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/api/embed", r.URL.Path) - _, _ = w.Write([]byte(`{"embeddings":[]}`)) - })) - defer countServer.Close() - - provider := newOllamaProvider(providerSettings{ - HTTPClient: countServer.Client(), - BaseURL: countServer.URL, - Model: "fallback-model", - MaxInputChars: 10, - }) - _, err := provider.Embed(context.Background(), []string{"one"}) - require.ErrorContains(t, err, "returned 0 vectors for 1 inputs") - - modelServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/api/embed", r.URL.Path) - _, _ = w.Write([]byte(`{"embeddings":[[1,2]]}`)) - })) - defer modelServer.Close() - - provider = newOllamaProvider(providerSettings{ - HTTPClient: modelServer.Client(), - BaseURL: modelServer.URL, - Model: "fallback-model", - MaxInputChars: 10, - }) - batch, err := provider.Embed(context.Background(), []string{"one"}) - require.NoError(t, err) - require.Equal(t, "fallback-model", batch.Model) -} - -func TestCheckProviderSkipsRemoteCompatibleProbe(t *testing.T) { - t.Parallel() - - result := CheckProvider(context.Background(), config.EmbeddingsConfig{ - Provider: ProviderOpenAICompatible, - Model: "remote-model", - BaseURL: "https://api.example.com/v1", - RequestTimeout: "5s", - }) - require.Equal(t, "ok", result.Status) - require.False(t, result.Probed) - require.Empty(t, result.Warning) -} diff --git a/internal/store/embeddings.go b/internal/store/embeddings.go index 8c487ba..888428a 100644 --- a/internal/store/embeddings.go +++ b/internal/store/embeddings.go @@ -1,15 +1,14 @@ package store import ( - "bytes" "context" - "encoding/binary" "errors" "fmt" "strings" "time" - "github.com/openclaw/discrawl/internal/embed" + "github.com/openclaw/crawlkit/embed" + "github.com/openclaw/crawlkit/vector" ) const ( @@ -476,28 +475,23 @@ func capRunes(value string, maxChars int) string { return string(runes[:maxChars]) } -func EncodeEmbeddingVector(vector []float32) ([]byte, error) { - buf := bytes.NewBuffer(make([]byte, 0, len(vector)*4)) - for _, value := range vector { - if err := binary.Write(buf, binary.LittleEndian, value); err != nil { - return nil, fmt.Errorf("encode embedding vector: %w", err) - } +func EncodeEmbeddingVector(values []float32) ([]byte, error) { + blob, err := vector.EncodeFloat32(values) + if err != nil { + return nil, fmt.Errorf("encode embedding vector: %w", err) } - return buf.Bytes(), nil + return blob, nil } func DecodeEmbeddingVector(blob []byte) ([]float32, error) { if len(blob)%4 != 0 { return nil, fmt.Errorf("embedding blob length %d is not a float32 multiple", len(blob)) } - out := make([]float32, len(blob)/4) - reader := bytes.NewReader(blob) - for i := range out { - if err := binary.Read(reader, binary.LittleEndian, &out[i]); err != nil { - return nil, fmt.Errorf("decode embedding vector: %w", err) - } + values, err := vector.DecodeFloat32(blob) + if err != nil { + return nil, fmt.Errorf("decode embedding vector: %w", err) } - return out, nil + return values, nil } func (s *Store) EmbeddingBacklog(ctx context.Context) (int, error) { diff --git a/internal/store/query.go b/internal/store/query.go index d8a533a..4d80db2 100644 --- a/internal/store/query.go +++ b/internal/store/query.go @@ -5,11 +5,12 @@ import ( "database/sql" "errors" "fmt" - "math" "os" "sort" "strings" "time" + + "github.com/openclaw/crawlkit/vector" ) const ( @@ -160,7 +161,7 @@ func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchO if len(opts.QueryVector) != opts.Dimensions { return nil, fmt.Errorf("semantic query embedding dimensions mismatch: got %d want %d", len(opts.QueryVector), opts.Dimensions) } - queryNorm := vectorNorm(opts.QueryVector) + queryNorm := vector.Norm(opts.QueryVector) if queryNorm == 0 { return nil, errors.New("semantic query embedding returned a zero vector") } @@ -236,15 +237,18 @@ func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchO if dimensions != opts.Dimensions { return nil, fmt.Errorf("stored embedding dimensions mismatch for message %s: got %d want %d", row.MessageID, dimensions, opts.Dimensions) } - vector, err := DecodeEmbeddingVector(blob) + storedVector, err := DecodeEmbeddingVector(blob) if err != nil { return nil, fmt.Errorf("decode embedding for message %s: %w", row.MessageID, err) } - if len(vector) != dimensions { - return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(vector), dimensions) + if len(storedVector) != dimensions { + return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(storedVector), dimensions) } - score, err := cosineSimilarity(opts.QueryVector, queryNorm, vector) + score, err := vector.CosineSimilarity(opts.QueryVector, queryNorm, storedVector) if err != nil { + if strings.Contains(err.Error(), "candidate vector is zero") { + return nil, fmt.Errorf("score embedding for message %s: stored embedding vector is zero", row.MessageID) + } return nil, fmt.Errorf("score embedding for message %s: %w", row.MessageID, err) } row.CreatedAt = parseTime(created) @@ -328,26 +332,23 @@ func fuseSearchResults(ftsResults, semanticResults []SearchResult, limit int) [] if limit <= 0 { limit = 20 } - entries := make(map[string]*hybridSearchEntry, len(ftsResults)+len(semanticResults)) - addResults := func(results []SearchResult, weight float64, fts bool) { - for index, result := range results { - entry := entries[result.MessageID] - if entry == nil { - entry = &hybridSearchEntry{result: result} - entries[result.MessageID] = entry - } - if fts { - entry.hasFTS = true - } - entry.score += weight / (rrfK + float64(index+1)) - } + id := func(result SearchResult) string { + return result.MessageID } - addResults(ftsResults, ftsRRFWeight, true) - addResults(semanticResults, semanticRRFWeight, false) - - merged := make([]hybridSearchEntry, 0, len(entries)) - for _, entry := range entries { - merged = append(merged, *entry) + ftsIDs := make(map[string]struct{}, len(ftsResults)) + for _, result := range ftsResults { + ftsIDs[result.MessageID] = struct{}{} + } + fused := vector.ReciprocalRankFusion( + [][]SearchResult{ftsResults, semanticResults}, + []func(SearchResult) string{id, id}, + []float64{ftsRRFWeight, semanticRRFWeight}, + rrfK, + ) + merged := make([]hybridSearchEntry, 0, len(fused)) + for _, entry := range fused { + _, hasFTS := ftsIDs[entry.Item.MessageID] + merged = append(merged, hybridSearchEntry{result: entry.Item, score: entry.Score, hasFTS: hasFTS}) } sort.SliceStable(merged, func(i, j int) bool { if merged[i].score != merged[j].score { @@ -490,29 +491,6 @@ func (s *Store) searchFallback(ctx context.Context, opts SearchOptions) ([]Searc return out, rows.Err() } -func cosineSimilarity(query []float32, queryNorm float64, vector []float32) (float64, error) { - if len(vector) != len(query) { - return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(vector), len(query)) - } - vectorNorm := vectorNorm(vector) - if vectorNorm == 0 { - return 0, errors.New("stored embedding vector is zero") - } - var dot float64 - for i := range query { - dot += float64(query[i]) * float64(vector[i]) - } - return dot / (queryNorm * vectorNorm), nil -} - -func vectorNorm(vector []float32) float64 { - var sum float64 - for _, value := range vector { - sum += float64(value) * float64(value) - } - return math.Sqrt(sum) -} - func (s *Store) Members(ctx context.Context, guildID, query string, limit int) ([]MemberRow, error) { if strings.TrimSpace(query) != "" { return s.searchMembers(ctx, guildID, query, limit) diff --git a/internal/store/store_write_test.go b/internal/store/store_write_test.go index b39ce07..ae1b5de 100644 --- a/internal/store/store_write_test.go +++ b/internal/store/store_write_test.go @@ -10,9 +10,8 @@ import ( "testing" "time" + "github.com/openclaw/crawlkit/embed" "github.com/stretchr/testify/require" - - "github.com/openclaw/discrawl/internal/embed" ) func TestUpsertMessagesBatch(t *testing.T) {