From 1cc2c6628303f013e60cdd91656eb849c3f3469f Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Fri, 8 May 2026 09:56:40 +0100 Subject: [PATCH] feat: add shared embedding and vector helpers --- CHANGELOG.md | 5 + README.md | 2 + embed/ollama.go | 91 ++++++++ embed/openai_compatible.go | 82 +++++++ embed/provider.go | 317 ++++++++++++++++++++++++++ embed/provider_test.go | 453 +++++++++++++++++++++++++++++++++++++ vector/vector.go | 142 ++++++++++++ vector/vector_test.go | 137 +++++++++++ 8 files changed, 1229 insertions(+) create mode 100644 embed/ollama.go create mode 100644 embed/openai_compatible.go create mode 100644 embed/provider.go create mode 100644 embed/provider_test.go create mode 100644 vector/vector.go create mode 100644 vector/vector_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f11b0a..5db1b64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## Unreleased +- Add reusable `embed` providers for OpenAI, OpenAI-compatible endpoints, + Ollama, and llama.cpp, including probe diagnostics and rate-limit errors. +- Add reusable `vector` helpers for float32 blobs, dimension validation, + cosine similarity, top-k sorting, and reciprocal-rank fusion. + ## v0.4.2 - 2026-05-08 - Add snapshot file fingerprints and an incremental import planner/executor so downstream apps can import changed JSONL/Gzip shards without deleting every table. diff --git a/README.md b/README.md index de15960..689d4d6 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,8 @@ See `docs/boundary.md` for the crawlkit-versus-app ownership boundary. - `snapshot`: `manifest.json` plus JSONL/Gzip table snapshot export, file fingerprints, full import, and planned incremental shard import. - `mirror`: clone/init/pull/commit/push helpers for private snapshot repos. - `state`: generic crawler cursor and freshness records. +- `embed`: reusable OpenAI-compatible, Ollama, and llama.cpp embedding providers plus local probe diagnostics. +- `vector`: float32 vector encoding, dimension validation, cosine scoring, top-k helpers, and reciprocal-rank fusion. - `output`: text/json/log output helpers. - `control`: crawl app metadata, command manifests, status payloads, and database inventory for launchers and automation. diff --git a/embed/ollama.go b/embed/ollama.go new file mode 100644 index 0000000..b5daa15 --- /dev/null +++ b/embed/ollama.go @@ -0,0 +1,91 @@ +package embed + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +type ollamaProvider struct { + client *http.Client + baseURL string + model string + maxInputChars int +} + +type ollamaEmbedRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +type ollamaEmbedResponse struct { + Model string `json:"model"` + Embeddings [][]float32 `json:"embeddings"` +} + +func newOllamaProvider(settings providerSettings) Provider { + return &ollamaProvider{ + client: settings.HTTPClient, + baseURL: settings.BaseURL, + model: settings.Model, + maxInputChars: settings.MaxInputChars, + } +} + +func (p *ollamaProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) { + if len(inputs) == 0 { + return EmbeddingBatch{Model: p.model}, nil + } + payload := ollamaEmbedRequest{ + Model: p.model, + Input: trimInputs(inputs, p.maxInputChars), + } + var response ollamaEmbedResponse + if err := postJSON(ctx, p.client, p.baseURL+"/api/embed", "", payload, &response); err != nil { + return EmbeddingBatch{}, err + } + if len(response.Embeddings) != len(inputs) { + return EmbeddingBatch{}, fmt.Errorf("ollama embedding response returned %d vectors for %d inputs", len(response.Embeddings), len(inputs)) + } + dimensions, err := inferDimensions(response.Embeddings) + if err != nil { + return EmbeddingBatch{}, err + } + model := response.Model + if model == "" { + model = p.model + } + return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: response.Embeddings}, nil +} + +func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, payload any, target any) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshal embedding request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("build embedding request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("embedding request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)} + } + if err := json.NewDecoder(resp.Body).Decode(target); err != nil { + return fmt.Errorf("decode embedding response: %w", err) + } + return nil +} diff --git a/embed/openai_compatible.go b/embed/openai_compatible.go new file mode 100644 index 0000000..4c65a5d --- /dev/null +++ b/embed/openai_compatible.go @@ -0,0 +1,82 @@ +package embed + +import ( + "context" + "fmt" + "net/http" +) + +type openAICompatibleProvider struct { + client *http.Client + baseURL string + apiKey string + model string + maxInputChars int +} + +type openAIEmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +type openAIEmbeddingResponse struct { + Model string `json:"model"` + Data []openAIEmbeddingItem `json:"data"` +} + +type openAIEmbeddingItem struct { + Index *int `json:"index"` + Embedding []float32 `json:"embedding"` +} + +func newOpenAICompatibleProvider(settings providerSettings) Provider { + return &openAICompatibleProvider{ + client: settings.HTTPClient, + baseURL: settings.BaseURL, + apiKey: settings.APIKey, + model: settings.Model, + maxInputChars: settings.MaxInputChars, + } +} + +func (p *openAICompatibleProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) { + if len(inputs) == 0 { + return EmbeddingBatch{Model: p.model}, nil + } + payload := openAIEmbeddingRequest{ + Model: p.model, + Input: trimInputs(inputs, p.maxInputChars), + } + var response openAIEmbeddingResponse + if err := postJSON(ctx, p.client, p.baseURL+"/embeddings", p.apiKey, payload, &response); err != nil { + return EmbeddingBatch{}, err + } + if len(response.Data) != len(inputs) { + return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response returned %d vectors for %d inputs", len(response.Data), len(inputs)) + } + vectors := make([][]float32, len(inputs)) + seen := make([]bool, len(inputs)) + for position, item := range response.Data { + index := position + if item.Index != nil { + index = *item.Index + } + if index < 0 || index >= len(inputs) { + return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response index %d out of range", index) + } + if seen[index] { + return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response duplicated index %d", index) + } + seen[index] = true + vectors[index] = item.Embedding + } + dimensions, err := inferDimensions(vectors) + if err != nil { + return EmbeddingBatch{}, err + } + model := response.Model + if model == "" { + model = p.model + } + return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: vectors}, nil +} diff --git a/embed/provider.go b/embed/provider.go new file mode 100644 index 0000000..c27cf39 --- /dev/null +++ b/embed/provider.go @@ -0,0 +1,317 @@ +package embed + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" +) + +const ( + ProviderOpenAI = "openai" + ProviderOllama = "ollama" + ProviderLlamaCpp = "llamacpp" + ProviderOpenAICompatible = "openai_compatible" + + DefaultOpenAIBaseURL = "https://api.openai.com/v1" + DefaultOllamaBaseURL = "http://127.0.0.1:11434" + DefaultLlamaCppBaseURL = "http://127.0.0.1:8080/v1" + DefaultOpenAIModel = "text-embedding-3-small" + DefaultLocalEmbeddingModel = "nomic-embed-text" + DefaultBatchSize = 64 + DefaultMaxInputChars = 12000 + DefaultRequestTimeout = 2 * time.Minute + DefaultProbeTimeout = 2 * time.Second +) + +type Config struct { + Provider string + Model string + BaseURL string + APIKeyEnv string + RequestTimeout string + MaxInputChars int +} + +type Provider interface { + Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) +} + +type EmbeddingBatch struct { + Model string + Dimensions int + Vectors [][]float32 +} + +type HTTPError struct { + StatusCode int + Body string +} + +func (e *HTTPError) Error() string { + return fmt.Sprintf("embedding request failed with HTTP %d: %s", e.StatusCode, e.Body) +} + +func IsRateLimitError(err error) bool { + var httpErr *HTTPError + return errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusTooManyRequests +} + +type CheckResult struct { + Provider string + Model string + BaseURL string + Status string + Warning string + Probed bool +} + +type Option func(*providerOptions) + +type providerOptions struct { + httpClient *http.Client + timeoutOverride time.Duration +} + +type providerSettings struct { + Name string + Model string + BaseURL string + APIKey string + MaxInputChars int + Timeout time.Duration + HTTPClient *http.Client +} + +func WithHTTPClient(client *http.Client) Option { + return func(opts *providerOptions) { + opts.httpClient = client + } +} + +func WithRequestTimeout(timeout time.Duration) Option { + return func(opts *providerOptions) { + opts.timeoutOverride = timeout + } +} + +func NewProvider(cfg Config, opts ...Option) (Provider, error) { + settings, err := resolveProviderConfig(cfg, true, opts...) + if err != nil { + return nil, err + } + return newProvider(settings) +} + +func CheckProvider(ctx context.Context, cfg Config) CheckResult { + settings, err := resolveProviderConfig(cfg, true, WithRequestTimeout(DefaultProbeTimeout)) + if err != nil { + return CheckResult{ + Provider: normalizedProviderName(cfg.Provider), + Model: strings.TrimSpace(cfg.Model), + BaseURL: strings.TrimSpace(cfg.BaseURL), + Status: "warning", + Warning: err.Error(), + } + } + result := CheckResult{ + Provider: settings.Name, + Model: settings.Model, + BaseURL: settings.BaseURL, + Status: "ok", + } + if !shouldProbe(settings) { + return result + } + provider, err := newProvider(settings) + if err != nil { + result.Status = "warning" + result.Warning = err.Error() + return result + } + probeCtx, cancel := context.WithTimeout(ctx, DefaultProbeTimeout) + defer cancel() + if _, err := provider.Embed(probeCtx, []string{"crawlkit probe"}); err != nil { + result.Status = "warning" + result.Warning = err.Error() + return result + } + result.Probed = true + return result +} + +func resolveProviderConfig(cfg Config, validateAPIKey bool, opts ...Option) (providerSettings, error) { + options := providerOptions{} + for _, opt := range opts { + opt(&options) + } + name := normalizedProviderName(cfg.Provider) + if name == "" { + name = ProviderOpenAI + } + model := strings.TrimSpace(cfg.Model) + if model == "" { + model = defaultModel(name) + } + baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/") + if baseURL == "" { + switch name { + case ProviderOpenAI: + baseURL = DefaultOpenAIBaseURL + case ProviderOllama: + baseURL = DefaultOllamaBaseURL + case ProviderLlamaCpp: + baseURL = DefaultLlamaCppBaseURL + case ProviderOpenAICompatible: + return providerSettings{}, fmt.Errorf("embedding provider %q requires base_url", name) + } + } + timeout := DefaultRequestTimeout + if strings.TrimSpace(cfg.RequestTimeout) != "" { + parsed, err := time.ParseDuration(cfg.RequestTimeout) + if err != nil { + return providerSettings{}, fmt.Errorf("parse embeddings request_timeout: %w", err) + } + if parsed <= 0 { + return providerSettings{}, errors.New("embeddings request_timeout must be positive") + } + timeout = parsed + } + if options.timeoutOverride > 0 && options.timeoutOverride < timeout { + timeout = options.timeoutOverride + } + maxInputChars := cfg.MaxInputChars + if maxInputChars <= 0 { + maxInputChars = DefaultMaxInputChars + } + switch name { + case ProviderOpenAI, ProviderOllama, ProviderLlamaCpp, ProviderOpenAICompatible: + default: + return providerSettings{}, fmt.Errorf("unsupported embedding provider %q", name) + } + apiKey, err := resolveAPIKey(name, cfg.APIKeyEnv, validateAPIKey) + if err != nil { + return providerSettings{}, err + } + client := options.httpClient + if client == nil { + client = &http.Client{Timeout: timeout} + } + if _, err := url.ParseRequestURI(baseURL); err != nil { + return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: %w", baseURL, err) + } + return providerSettings{ + Name: name, + Model: model, + BaseURL: baseURL, + APIKey: apiKey, + MaxInputChars: maxInputChars, + Timeout: timeout, + HTTPClient: client, + }, nil +} + +func newProvider(settings providerSettings) (Provider, error) { + switch settings.Name { + case ProviderOllama: + return newOllamaProvider(settings), nil + case ProviderOpenAI, ProviderLlamaCpp, ProviderOpenAICompatible: + return newOpenAICompatibleProvider(settings), nil + default: + return nil, fmt.Errorf("unsupported embedding provider %q", settings.Name) + } +} + +func resolveAPIKey(provider, apiKeyEnv string, validate bool) (string, error) { + envName := strings.TrimSpace(apiKeyEnv) + required := provider == ProviderOpenAI + if envName == "" { + if required { + envName = "OPENAI_API_KEY" + } else { + return "", nil + } + } + value := strings.TrimSpace(os.Getenv(envName)) + if value == "" { + if required || validate { + return "", fmt.Errorf("embedding provider %q requires API key env %s", provider, envName) + } + return "", nil + } + return value, nil +} + +func normalizedProviderName(provider string) string { + return strings.ToLower(strings.TrimSpace(provider)) +} + +func defaultModel(provider string) string { + switch provider { + case ProviderOllama, ProviderLlamaCpp: + return DefaultLocalEmbeddingModel + default: + return DefaultOpenAIModel + } +} + +func shouldProbe(settings providerSettings) bool { + switch settings.Name { + case ProviderOllama, ProviderLlamaCpp: + return true + case ProviderOpenAICompatible: + return isLoopbackBaseURL(settings.BaseURL) + default: + return false + } +} + +func isLoopbackBaseURL(rawURL string) bool { + parsed, err := url.Parse(rawURL) + if err != nil { + return false + } + host := parsed.Hostname() + if host == "localhost" { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} + +func trimInputs(inputs []string, maxChars int) []string { + if maxChars <= 0 { + maxChars = DefaultMaxInputChars + } + out := make([]string, len(inputs)) + for i, input := range inputs { + runes := []rune(input) + if len(runes) > maxChars { + runes = runes[:maxChars] + } + out[i] = string(runes) + } + return out +} + +func inferDimensions(vectors [][]float32) (int, error) { + dimensions := 0 + for _, vector := range vectors { + if len(vector) == 0 { + return 0, errors.New("embedding response contained an empty vector") + } + if dimensions == 0 { + dimensions = len(vector) + continue + } + if len(vector) != dimensions { + return 0, fmt.Errorf("embedding response dimensions mismatch: got %d want %d", len(vector), dimensions) + } + } + return dimensions, nil +} diff --git a/embed/provider_test.go b/embed/provider_test.go new file mode 100644 index 0000000..70857c8 --- /dev/null +++ b/embed/provider_test.go @@ -0,0 +1,453 @@ +package embed + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" +) + +func TestOllamaProviderEmbeds(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/embed", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + var req ollamaEmbedRequest + assert.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + assert.Equal(t, "nomic-embed-text", req.Model) + assert.Equal(t, []string{"abcd", "xy"}, req.Input) + _, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`)) + })) + defer server.Close() + + provider, err := NewProvider(Config{ + Provider: ProviderOllama, + Model: "nomic-embed-text", + BaseURL: server.URL, + MaxInputChars: 4, + RequestTimeout: "5s", + }) + require.NoError(t, err) + + batch, err := provider.Embed(context.Background(), []string{"abcdef", "xy"}) + require.NoError(t, err) + require.Equal(t, "nomic-embed-text", batch.Model) + require.Equal(t, 3, batch.Dimensions) + require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, batch.Vectors) +} + +type assertAPI struct{} +type requireAPI struct{} + +var assert assertAPI +var require requireAPI + +func (assertAPI) Equal(t *testing.T, want, got any) bool { + t.Helper() + if !reflect.DeepEqual(want, got) { + t.Errorf("not equal:\nwant: %#v\n got: %#v", want, got) + return false + } + return true +} + +func (assertAPI) NoError(t *testing.T, err error) bool { + t.Helper() + if err != nil { + t.Errorf("unexpected error: %v", err) + return false + } + return true +} + +func (requireAPI) NoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func (requireAPI) Equal(t *testing.T, want, got any) { + t.Helper() + if !reflect.DeepEqual(want, got) { + t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got) + } +} + +func (requireAPI) Same(t *testing.T, want, got any) { + t.Helper() + if !reflect.ValueOf(want).IsValid() || !reflect.ValueOf(got).IsValid() || + reflect.ValueOf(want).Pointer() != reflect.ValueOf(got).Pointer() { + t.Fatalf("not same:\nwant: %#v\n got: %#v", want, got) + } +} + +func (requireAPI) True(t *testing.T, value bool) { + t.Helper() + if !value { + t.Fatal("expected true") + } +} + +func (requireAPI) False(t *testing.T, value bool) { + t.Helper() + if value { + t.Fatal("expected false") + } +} + +func (requireAPI) Empty(t *testing.T, value string) { + t.Helper() + if value != "" { + t.Fatalf("expected empty string, got %q", value) + } +} + +func (requireAPI) Contains(t *testing.T, value, needle string) { + t.Helper() + if !strings.Contains(value, needle) { + t.Fatalf("expected %q to contain %q", value, needle) + } +} + +func (requireAPI) ErrorContains(t *testing.T, err error, needle string) { + t.Helper() + if err == nil { + t.Fatalf("expected error containing %q, got nil", needle) + } + if !strings.Contains(err.Error(), needle) { + t.Fatalf("expected error containing %q, got %q", needle, err.Error()) + } +} + +func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/embeddings", r.URL.Path) + assert.Equal(t, "Bearer secret", r.Header.Get("Authorization")) + var req openAIEmbeddingRequest + assert.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + assert.Equal(t, "local-model", req.Model) + assert.Equal(t, []string{"one", "two"}, req.Input) + _, _ = w.Write([]byte(`{ + "model":"local-model", + "data":[ + {"index":1,"embedding":[3,4]}, + {"index":0,"embedding":[1,2]} + ] + }`)) + })) + defer server.Close() + t.Setenv("CRAWLKIT_EMBED_KEY", "secret") + + provider, err := NewProvider(Config{ + Provider: ProviderOpenAICompatible, + Model: "local-model", + BaseURL: server.URL, + APIKeyEnv: "CRAWLKIT_EMBED_KEY", + RequestTimeout: "5s", + }) + require.NoError(t, err) + + batch, err := provider.Embed(context.Background(), []string{"one", "two"}) + require.NoError(t, err) + require.Equal(t, "local-model", batch.Model) + require.Equal(t, 2, batch.Dimensions) + require.Equal(t, [][]float32{{1, 2}, {3, 4}}, batch.Vectors) +} + +func TestProviderFactoryDefaultsAndValidation(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "openai-secret") + + openAI, err := resolveProviderConfig(Config{ + Provider: ProviderOpenAI, + RequestTimeout: "5s", + }, true) + require.NoError(t, err) + require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL) + require.Equal(t, DefaultOpenAIModel, openAI.Model) + require.Equal(t, "openai-secret", openAI.APIKey) + + ollama, err := resolveProviderConfig(Config{ + Provider: ProviderOllama, + RequestTimeout: "5s", + }, true) + require.NoError(t, err) + require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL) + require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model) + + llamaCpp, err := resolveProviderConfig(Config{ + Provider: ProviderLlamaCpp, + RequestTimeout: "5s", + }, true) + require.NoError(t, err) + require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL) + + _, err = resolveProviderConfig(Config{ + Provider: ProviderOpenAICompatible, + RequestTimeout: "5s", + }, true) + require.ErrorContains(t, err, "requires base_url") +} + +func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "") + + _, err := NewProvider(Config{ + Provider: ProviderOpenAI, + RequestTimeout: "5s", + }) + require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY") +} + +func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) { + t.Setenv("MISSING_EMBED_KEY", "") + + _, err := NewProvider(Config{ + Provider: "bogus", + APIKeyEnv: "MISSING_EMBED_KEY", + RequestTimeout: "5s", + }) + require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"") +} + +func TestCheckProviderProbesLocalProvider(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/embed", r.URL.Path) + _, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`)) + })) + defer server.Close() + + result := CheckProvider(context.Background(), Config{ + Provider: ProviderOllama, + Model: "nomic-embed-text", + BaseURL: server.URL, + RequestTimeout: "5s", + }) + require.Equal(t, "ok", result.Status) + require.True(t, result.Probed) + require.Empty(t, result.Warning) + require.Equal(t, server.URL, result.BaseURL) +} + +func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not ready", http.StatusServiceUnavailable) + })) + defer server.Close() + + result := CheckProvider(context.Background(), Config{ + Provider: ProviderOllama, + Model: "nomic-embed-text", + BaseURL: server.URL, + RequestTimeout: "5s", + }) + require.Equal(t, "warning", result.Status) + require.Contains(t, result.Warning, "HTTP 503") + require.False(t, result.Probed) +} + +func TestProviderExposesRateLimitErrors(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "rate limited", http.StatusTooManyRequests) + })) + defer server.Close() + + provider, err := NewProvider(Config{ + Provider: ProviderOpenAICompatible, + Model: "local-model", + BaseURL: server.URL, + RequestTimeout: "5s", + }) + require.NoError(t, err) + + _, err = provider.Embed(context.Background(), []string{"one"}) + require.ErrorContains(t, err, "HTTP 429") + require.True(t, IsRateLimitError(err)) +} + +func TestProviderRejectsInvalidResponses(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`)) + })) + defer server.Close() + + provider, err := NewProvider(Config{ + Provider: ProviderOpenAICompatible, + Model: "local-model", + BaseURL: server.URL, + RequestTimeout: "5s", + }) + require.NoError(t, err) + + _, err = provider.Embed(context.Background(), []string{"one", "two"}) + require.ErrorContains(t, err, "dimensions mismatch") +} + +func TestEmbeddingProvidersHandleEmptyInputsAndIndexErrors(t *testing.T) { + t.Parallel() + + settings := providerSettings{ + Name: ProviderOllama, + Model: "model", + BaseURL: "http://127.0.0.1:1", + MaxInputChars: 10, + HTTPClient: http.DefaultClient, + } + ollama := newOllamaProvider(settings) + batch, err := ollama.Embed(context.Background(), nil) + require.NoError(t, err) + require.Equal(t, "model", batch.Model) + + settings.Name = ProviderOpenAICompatible + openai := newOpenAICompatibleProvider(settings) + batch, err = openai.Embed(context.Background(), nil) + require.NoError(t, err) + require.Equal(t, "model", batch.Model) + + tests := []struct { + name string + body string + inputs []string + want string + }{ + {name: "count", body: `{"data":[]}`, inputs: []string{"one"}, want: "returned 0 vectors for 1 inputs"}, + {name: "range", body: `{"data":[{"index":2,"embedding":[1]}]}`, inputs: []string{"one"}, want: "index 2 out of range"}, + {name: "duplicate", body: `{"data":[{"index":0,"embedding":[1]},{"index":0,"embedding":[2]}]}`, inputs: []string{"one", "two"}, want: "duplicated index 0"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(tc.body)) + })) + defer server.Close() + provider, err := NewProvider(Config{ + Provider: ProviderOpenAICompatible, + Model: "model", + BaseURL: server.URL, + RequestTimeout: "5s", + }) + require.NoError(t, err) + _, err = provider.Embed(context.Background(), tc.inputs) + require.ErrorContains(t, err, tc.want) + }) + } +} + +func TestProviderOptionsAndProbeDecisions(t *testing.T) { + t.Parallel() + + client := &http.Client{Timeout: time.Second} + settings, err := resolveProviderConfig(Config{ + Provider: ProviderOllama, + BaseURL: "http://127.0.0.1:11434/", + RequestTimeout: "30s", + }, true, WithHTTPClient(client), WithRequestTimeout(50*time.Millisecond)) + require.NoError(t, err) + require.Same(t, client, settings.HTTPClient) + require.Equal(t, 50*time.Millisecond, settings.Timeout) + require.Equal(t, "http://127.0.0.1:11434", settings.BaseURL) + require.True(t, shouldProbe(settings)) + + require.True(t, isLoopbackBaseURL("http://localhost:8080/v1")) + require.True(t, isLoopbackBaseURL("http://[::1]:8080/v1")) + require.False(t, isLoopbackBaseURL("https://api.example.com/v1")) + require.False(t, isLoopbackBaseURL("://bad")) + require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAI})) + require.True(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "http://localhost:8080/v1"})) + require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "https://api.example.com/v1"})) +} + +func TestProviderValidationEdges(t *testing.T) { + t.Parallel() + + _, err := resolveProviderConfig(Config{ + Provider: ProviderOllama, + RequestTimeout: "not-a-duration", + }, true) + require.ErrorContains(t, err, "parse embeddings request_timeout") + + _, err = resolveProviderConfig(Config{ + Provider: ProviderOllama, + RequestTimeout: "0s", + }, true) + require.ErrorContains(t, err, "must be positive") + + _, err = resolveProviderConfig(Config{ + Provider: ProviderOllama, + BaseURL: "://bad", + }, true) + require.ErrorContains(t, err, "invalid embeddings base_url") + + key, err := resolveAPIKey(ProviderOpenAICompatible, "MISSING_EMBED_KEY", false) + require.NoError(t, err) + require.Empty(t, key) + + _, err = newProvider(providerSettings{Name: "bogus"}) + require.ErrorContains(t, err, "unsupported embedding provider") + + require.Equal(t, []string{"abc"}, trimInputs([]string{"abc"}, 0)) + _, err = inferDimensions([][]float32{{}}) + require.ErrorContains(t, err, "empty vector") +} + +func TestOllamaProviderResponseEdges(t *testing.T) { + t.Parallel() + + countServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/embed", r.URL.Path) + _, _ = w.Write([]byte(`{"embeddings":[]}`)) + })) + defer countServer.Close() + + provider := newOllamaProvider(providerSettings{ + HTTPClient: countServer.Client(), + BaseURL: countServer.URL, + Model: "fallback-model", + MaxInputChars: 10, + }) + _, err := provider.Embed(context.Background(), []string{"one"}) + require.ErrorContains(t, err, "returned 0 vectors for 1 inputs") + + modelServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/embed", r.URL.Path) + _, _ = w.Write([]byte(`{"embeddings":[[1,2]]}`)) + })) + defer modelServer.Close() + + provider = newOllamaProvider(providerSettings{ + HTTPClient: modelServer.Client(), + BaseURL: modelServer.URL, + Model: "fallback-model", + MaxInputChars: 10, + }) + batch, err := provider.Embed(context.Background(), []string{"one"}) + require.NoError(t, err) + require.Equal(t, "fallback-model", batch.Model) +} + +func TestCheckProviderSkipsRemoteCompatibleProbe(t *testing.T) { + t.Parallel() + + result := CheckProvider(context.Background(), Config{ + Provider: ProviderOpenAICompatible, + Model: "remote-model", + BaseURL: "https://api.example.com/v1", + RequestTimeout: "5s", + }) + require.Equal(t, "ok", result.Status) + require.False(t, result.Probed) + require.Empty(t, result.Warning) +} diff --git a/vector/vector.go b/vector/vector.go new file mode 100644 index 0000000..1a4247c --- /dev/null +++ b/vector/vector.go @@ -0,0 +1,142 @@ +package vector + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "math" + "sort" +) + +const DefaultRRFK = 60.0 + +type Scored[T any] struct { + Item T + Score float64 +} + +type RRFEntry[T any] struct { + Item T + Score float64 +} + +func EncodeFloat32(values []float32) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, len(values)*4)) + for _, value := range values { + if err := binary.Write(buf, binary.LittleEndian, value); err != nil { + return nil, fmt.Errorf("encode float32 vector: %w", err) + } + } + return buf.Bytes(), nil +} + +func DecodeFloat32(blob []byte) ([]float32, error) { + if len(blob)%4 != 0 { + return nil, fmt.Errorf("float32 vector blob length %d is not a multiple of 4", len(blob)) + } + out := make([]float32, len(blob)/4) + reader := bytes.NewReader(blob) + for i := range out { + if err := binary.Read(reader, binary.LittleEndian, &out[i]); err != nil { + return nil, fmt.Errorf("decode float32 vector: %w", err) + } + } + return out, nil +} + +func ValidateDimensions(values []float32, dimensions int) error { + if dimensions <= 0 { + return errors.New("dimensions must be positive") + } + if len(values) != dimensions { + return fmt.Errorf("dimensions mismatch: got %d want %d", len(values), dimensions) + } + return nil +} + +func Norm(values []float32) float64 { + var sum float64 + for _, value := range values { + sum += float64(value) * float64(value) + } + return math.Sqrt(sum) +} + +func CosineSimilarity(query []float32, queryNorm float64, candidate []float32) (float64, error) { + if len(candidate) != len(query) { + return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(candidate), len(query)) + } + if queryNorm == 0 { + return 0, errors.New("query vector is zero") + } + candidateNorm := Norm(candidate) + if candidateNorm == 0 { + return 0, errors.New("candidate vector is zero") + } + var dot float64 + for i := range query { + dot += float64(query[i]) * float64(candidate[i]) + } + return dot / (queryNorm * candidateNorm), nil +} + +func TopK[T any](items []Scored[T], limit int, tieLess func(left, right T) bool) []Scored[T] { + if limit <= 0 || len(items) == 0 { + return nil + } + sorted := append([]Scored[T](nil), items...) + sort.SliceStable(sorted, func(i, j int) bool { + if sorted[i].Score != sorted[j].Score { + return sorted[i].Score > sorted[j].Score + } + if tieLess == nil { + return false + } + return tieLess(sorted[i].Item, sorted[j].Item) + }) + if len(sorted) > limit { + sorted = sorted[:limit] + } + return sorted +} + +func ReciprocalRankFusion[T any](rankings [][]T, ids []func(T) string, weights []float64, k float64) []RRFEntry[T] { + if k <= 0 { + k = DefaultRRFK + } + entries := map[string]*RRFEntry[T]{} + for rankingIndex, ranking := range rankings { + weight := 1.0 + if rankingIndex < len(weights) && weights[rankingIndex] != 0 { + weight = weights[rankingIndex] + } + var idFn func(T) string + if rankingIndex < len(ids) { + idFn = ids[rankingIndex] + } + for index, item := range ranking { + if idFn == nil { + continue + } + id := idFn(item) + if id == "" { + continue + } + entry := entries[id] + if entry == nil { + entry = &RRFEntry[T]{Item: item} + entries[id] = entry + } + entry.Score += weight / (k + float64(index+1)) + } + } + out := make([]RRFEntry[T], 0, len(entries)) + for _, entry := range entries { + out = append(out, *entry) + } + sort.SliceStable(out, func(i, j int) bool { + return out[i].Score > out[j].Score + }) + return out +} diff --git a/vector/vector_test.go b/vector/vector_test.go new file mode 100644 index 0000000..400ac87 --- /dev/null +++ b/vector/vector_test.go @@ -0,0 +1,137 @@ +package vector + +import ( + "math" + "reflect" + "strings" + "testing" +) + +func TestFloat32EncodingRoundTrip(t *testing.T) { + blob, err := EncodeFloat32([]float32{1, -2.5, 3.25}) + require.NoError(t, err) + require.Len(t, blob, 12) + + values, err := DecodeFloat32(blob) + require.NoError(t, err) + require.Equal(t, []float32{1, -2.5, 3.25}, values) + + _, err = DecodeFloat32([]byte{1, 2, 3}) + require.ErrorContains(t, err, "not a multiple of 4") +} + +func TestCosineSimilarityAndDimensions(t *testing.T) { + require.NoError(t, ValidateDimensions([]float32{1, 2}, 2)) + require.ErrorContains(t, ValidateDimensions([]float32{1}, 2), "dimensions mismatch") + require.ErrorContains(t, ValidateDimensions([]float32{1}, 0), "positive") + + query := []float32{1, 0} + score, err := CosineSimilarity(query, Norm(query), []float32{0.5, 0}) + require.NoError(t, err) + require.InDelta(t, 1, score, 0.0001) + + _, err = CosineSimilarity(query, 0, []float32{1, 0}) + require.ErrorContains(t, err, "query vector is zero") + _, err = CosineSimilarity(query, Norm(query), []float32{0, 0}) + require.ErrorContains(t, err, "candidate vector is zero") + _, err = CosineSimilarity(query, Norm(query), []float32{1}) + require.ErrorContains(t, err, "dimensions mismatch") + require.Equal(t, math.Sqrt(5), Norm([]float32{1, 2})) +} + +func TestTopK(t *testing.T) { + items := []Scored[string]{ + {Item: "c", Score: 0.3}, + {Item: "a", Score: 0.5}, + {Item: "b", Score: 0.5}, + } + top := TopK(items, 2, func(left, right string) bool { return left < right }) + require.Equal(t, []Scored[string]{{Item: "a", Score: 0.5}, {Item: "b", Score: 0.5}}, top) + require.Nil(t, TopK(items, 0, nil)) +} + +func TestReciprocalRankFusion(t *testing.T) { + rankings := [][]string{ + {"a", "b"}, + {"b", "c"}, + } + ids := []func(string) string{ + func(value string) string { return value }, + func(value string) string { return value }, + } + results := ReciprocalRankFusion(rankings, ids, []float64{1, 1}, 60) + require.Len(t, results, 3) + require.Equal(t, "b", results[0].Item) + require.Greater(t, results[0].Score, results[1].Score) +} + +type requireAPI struct{} + +var require requireAPI + +func (requireAPI) NoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func (requireAPI) Equal(t *testing.T, want, got any) { + t.Helper() + if !reflect.DeepEqual(want, got) { + t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got) + } +} + +func (requireAPI) Len(t *testing.T, value any, want int) { + t.Helper() + got := reflect.ValueOf(value).Len() + if got != want { + t.Fatalf("len mismatch: got %d want %d", got, want) + } +} + +func (requireAPI) Nil(t *testing.T, value any) { + t.Helper() + if !isNil(value) { + t.Fatalf("expected nil, got %#v", value) + } +} + +func (requireAPI) Greater(t *testing.T, left, right float64) { + t.Helper() + if left <= right { + t.Fatalf("expected %v > %v", left, right) + } +} + +func (requireAPI) InDelta(t *testing.T, want, got, delta float64) { + t.Helper() + diff := math.Abs(want - got) + if diff > delta { + t.Fatalf("not within delta: want %v got %v delta %v", want, got, delta) + } +} + +func (requireAPI) ErrorContains(t *testing.T, err error, needle string) { + t.Helper() + if err == nil { + t.Fatalf("expected error containing %q, got nil", needle) + } + if !strings.Contains(err.Error(), needle) { + t.Fatalf("expected error containing %q, got %q", needle, err.Error()) + } +} + +func isNil(value any) bool { + if value == nil { + return true + } + reflected := reflect.ValueOf(value) + switch reflected.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + return reflected.IsNil() + default: + return false + } +}