feat: add shared embedding and vector helpers
This commit is contained in:
parent
7fbca35339
commit
1cc2c66283
@ -2,6 +2,11 @@
|
||||
|
||||
## Unreleased
|
||||
|
||||
- Add reusable `embed` providers for OpenAI, OpenAI-compatible endpoints,
|
||||
Ollama, and llama.cpp, including probe diagnostics and rate-limit errors.
|
||||
- Add reusable `vector` helpers for float32 blobs, dimension validation,
|
||||
cosine similarity, top-k sorting, and reciprocal-rank fusion.
|
||||
|
||||
## v0.4.2 - 2026-05-08
|
||||
|
||||
- Add snapshot file fingerprints and an incremental import planner/executor so downstream apps can import changed JSONL/Gzip shards without deleting every table.
|
||||
|
||||
@ -25,6 +25,8 @@ See `docs/boundary.md` for the crawlkit-versus-app ownership boundary.
|
||||
- `snapshot`: `manifest.json` plus JSONL/Gzip table snapshot export, file fingerprints, full import, and planned incremental shard import.
|
||||
- `mirror`: clone/init/pull/commit/push helpers for private snapshot repos.
|
||||
- `state`: generic crawler cursor and freshness records.
|
||||
- `embed`: reusable OpenAI-compatible, Ollama, and llama.cpp embedding providers plus local probe diagnostics.
|
||||
- `vector`: float32 vector encoding, dimension validation, cosine scoring, top-k helpers, and reciprocal-rank fusion.
|
||||
- `output`: text/json/log output helpers.
|
||||
- `control`: crawl app metadata, command manifests, status payloads, and
|
||||
database inventory for launchers and automation.
|
||||
|
||||
91
embed/ollama.go
Normal file
91
embed/ollama.go
Normal file
@ -0,0 +1,91 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ollamaProvider is the Provider implementation that talks to an Ollama
// server through its native embed endpoint.
type ollamaProvider struct {
	// client performs every HTTP request issued by Embed.
	client *http.Client
	// baseURL is the server root; Embed appends "/api/embed" to it.
	baseURL string
	// model is sent with each request and used as the fallback model name
	// when the response does not echo one.
	model string
	// maxInputChars is the per-input cap handed to trimInputs.
	maxInputChars int
}
|
||||
|
||||
// ollamaEmbedRequest is the JSON body posted to the Ollama embed endpoint.
type ollamaEmbedRequest struct {
	// Model names the embedding model to run.
	Model string `json:"model"`
	// Input carries the (already trimmed) texts to embed, in caller order.
	Input []string `json:"input"`
}
|
||||
|
||||
// ollamaEmbedResponse is the subset of the Ollama embed reply that Embed
// reads.
type ollamaEmbedResponse struct {
	// Model is the model name reported by the server; it may be empty.
	Model string `json:"model"`
	// Embeddings holds the returned vectors; Embed requires exactly one per
	// input.
	Embeddings [][]float32 `json:"embeddings"`
}
|
||||
|
||||
func newOllamaProvider(settings providerSettings) Provider {
|
||||
return &ollamaProvider{
|
||||
client: settings.HTTPClient,
|
||||
baseURL: settings.BaseURL,
|
||||
model: settings.Model,
|
||||
maxInputChars: settings.MaxInputChars,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ollamaProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
|
||||
if len(inputs) == 0 {
|
||||
return EmbeddingBatch{Model: p.model}, nil
|
||||
}
|
||||
payload := ollamaEmbedRequest{
|
||||
Model: p.model,
|
||||
Input: trimInputs(inputs, p.maxInputChars),
|
||||
}
|
||||
var response ollamaEmbedResponse
|
||||
if err := postJSON(ctx, p.client, p.baseURL+"/api/embed", "", payload, &response); err != nil {
|
||||
return EmbeddingBatch{}, err
|
||||
}
|
||||
if len(response.Embeddings) != len(inputs) {
|
||||
return EmbeddingBatch{}, fmt.Errorf("ollama embedding response returned %d vectors for %d inputs", len(response.Embeddings), len(inputs))
|
||||
}
|
||||
dimensions, err := inferDimensions(response.Embeddings)
|
||||
if err != nil {
|
||||
return EmbeddingBatch{}, err
|
||||
}
|
||||
model := response.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: response.Embeddings}, nil
|
||||
}
|
||||
|
||||
func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, payload any, target any) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal embedding request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build embedding request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("embedding request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
|
||||
return fmt.Errorf("decode embedding response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
82
embed/openai_compatible.go
Normal file
82
embed/openai_compatible.go
Normal file
@ -0,0 +1,82 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// openAICompatibleProvider is the Provider implementation for servers that
// expose an OpenAI-style "/embeddings" endpoint.
type openAICompatibleProvider struct {
	// client performs every HTTP request issued by Embed.
	client *http.Client
	// baseURL is the API root; Embed appends "/embeddings" to it.
	baseURL string
	// apiKey, when non-empty, is sent as a Bearer token.
	apiKey string
	// model is sent with each request and used as the fallback model name
	// when the response does not echo one.
	model string
	// maxInputChars is the per-input cap handed to trimInputs.
	maxInputChars int
}
|
||||
|
||||
// openAIEmbeddingRequest is the JSON body posted to an OpenAI-style
// "/embeddings" endpoint.
type openAIEmbeddingRequest struct {
	// Model names the embedding model to run.
	Model string `json:"model"`
	// Input carries the (already trimmed) texts to embed, in caller order.
	Input []string `json:"input"`
}
|
||||
|
||||
type openAIEmbeddingResponse struct {
|
||||
Model string `json:"model"`
|
||||
Data []openAIEmbeddingItem `json:"data"`
|
||||
}
|
||||
|
||||
// openAIEmbeddingItem is a single entry of the "data" array in an
// OpenAI-style embeddings reply.
type openAIEmbeddingItem struct {
	// Index is the position of the input this vector belongs to. It is a
	// pointer so an absent field can be told apart from index 0; Embed falls
	// back to the item's array position when it is nil.
	Index *int `json:"index"`
	// Embedding is the vector produced for that input.
	Embedding []float32 `json:"embedding"`
}
|
||||
|
||||
func newOpenAICompatibleProvider(settings providerSettings) Provider {
|
||||
return &openAICompatibleProvider{
|
||||
client: settings.HTTPClient,
|
||||
baseURL: settings.BaseURL,
|
||||
apiKey: settings.APIKey,
|
||||
model: settings.Model,
|
||||
maxInputChars: settings.MaxInputChars,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *openAICompatibleProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
|
||||
if len(inputs) == 0 {
|
||||
return EmbeddingBatch{Model: p.model}, nil
|
||||
}
|
||||
payload := openAIEmbeddingRequest{
|
||||
Model: p.model,
|
||||
Input: trimInputs(inputs, p.maxInputChars),
|
||||
}
|
||||
var response openAIEmbeddingResponse
|
||||
if err := postJSON(ctx, p.client, p.baseURL+"/embeddings", p.apiKey, payload, &response); err != nil {
|
||||
return EmbeddingBatch{}, err
|
||||
}
|
||||
if len(response.Data) != len(inputs) {
|
||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response returned %d vectors for %d inputs", len(response.Data), len(inputs))
|
||||
}
|
||||
vectors := make([][]float32, len(inputs))
|
||||
seen := make([]bool, len(inputs))
|
||||
for position, item := range response.Data {
|
||||
index := position
|
||||
if item.Index != nil {
|
||||
index = *item.Index
|
||||
}
|
||||
if index < 0 || index >= len(inputs) {
|
||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response index %d out of range", index)
|
||||
}
|
||||
if seen[index] {
|
||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response duplicated index %d", index)
|
||||
}
|
||||
seen[index] = true
|
||||
vectors[index] = item.Embedding
|
||||
}
|
||||
dimensions, err := inferDimensions(vectors)
|
||||
if err != nil {
|
||||
return EmbeddingBatch{}, err
|
||||
}
|
||||
model := response.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: vectors}, nil
|
||||
}
|
||||
317
embed/provider.go
Normal file
317
embed/provider.go
Normal file
@ -0,0 +1,317 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider names accepted in Config.Provider (after trimming and
// lower-casing).
const (
	ProviderOpenAI           = "openai"
	ProviderOllama           = "ollama"
	ProviderLlamaCpp         = "llamacpp"
	ProviderOpenAICompatible = "openai_compatible"
)

// Base URLs applied when Config.BaseURL is empty.
const (
	DefaultOpenAIBaseURL   = "https://api.openai.com/v1"
	DefaultOllamaBaseURL   = "http://127.0.0.1:11434"
	DefaultLlamaCppBaseURL = "http://127.0.0.1:8080/v1"
)

// Model names applied when Config.Model is empty.
const (
	DefaultOpenAIModel         = "text-embedding-3-small"
	DefaultLocalEmbeddingModel = "nomic-embed-text"
)

// Size and time limits.
const (
	// DefaultBatchSize is not referenced in this part of the package;
	// presumably a suggested batch size for callers — confirm at call sites.
	DefaultBatchSize = 64
	// DefaultMaxInputChars caps each input (in runes) when no positive limit
	// is configured.
	DefaultMaxInputChars = 12000
	// DefaultRequestTimeout is the HTTP client timeout used when
	// Config.RequestTimeout is empty.
	DefaultRequestTimeout = 2 * time.Minute
	// DefaultProbeTimeout bounds the probe request made by CheckProvider.
	DefaultProbeTimeout = 2 * time.Second
)
|
||||
|
||||
// Config is the user-facing description of an embedding provider. Every
// field is optional except where noted; resolveProviderConfig fills in
// defaults.
type Config struct {
	// Provider selects the backend (one of the Provider* constants). It is
	// trimmed and lower-cased; empty means ProviderOpenAI.
	Provider string
	// Model is the embedding model name; empty selects the provider default.
	Model string
	// BaseURL is the API root. Trailing slashes are dropped. Empty selects
	// the provider default, except for ProviderOpenAICompatible where it is
	// required.
	BaseURL string
	// APIKeyEnv names the environment variable holding the API key. When
	// empty, ProviderOpenAI reads OPENAI_API_KEY and other providers send no
	// key.
	APIKeyEnv string
	// RequestTimeout is a Go duration string (for example "5s"). Empty
	// selects DefaultRequestTimeout; a non-positive value is rejected.
	RequestTimeout string
	// MaxInputChars caps each input; values <= 0 select
	// DefaultMaxInputChars.
	MaxInputChars int
}
|
||||
|
||||
// Provider turns texts into embedding vectors.
type Provider interface {
	// Embed returns a batch describing the vectors computed for inputs.
	Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error)
}

// EmbeddingBatch is the result of a single Embed call.
type EmbeddingBatch struct {
	// Model is the model name that produced the vectors.
	Model string
	// Dimensions is the common length of every vector in Vectors.
	Dimensions int
	// Vectors holds the embedding vectors for the batch.
	Vectors [][]float32
}
|
||||
|
||||
// HTTPError reports a non-2xx response from an embedding endpoint.
type HTTPError struct {
	// StatusCode is the HTTP status returned by the server.
	StatusCode int
	// Body is the (possibly truncated) response body.
	Body string
}

// Error implements the error interface, including the status code and body.
func (e *HTTPError) Error() string {
	const layout = "embedding request failed with HTTP %d: %s"
	return fmt.Sprintf(layout, e.StatusCode, e.Body)
}

// IsRateLimitError reports whether err is, or wraps, an *HTTPError whose
// status is 429 Too Many Requests.
func IsRateLimitError(err error) bool {
	var target *HTTPError
	if !errors.As(err, &target) {
		return false
	}
	return target.StatusCode == http.StatusTooManyRequests
}
|
||||
|
||||
// CheckResult is the diagnostic summary returned by CheckProvider.
type CheckResult struct {
	// Provider is the normalized provider name.
	Provider string
	// Model is the model name that was (or would be) used.
	Model string
	// BaseURL is the endpoint root that was (or would be) used.
	BaseURL string
	// Status is "ok" or "warning".
	Status string
	// Warning carries the error text when Status is "warning".
	Warning string
	// Probed is true only when a probe request was sent and succeeded.
	Probed bool
}
|
||||
|
||||
// Option customizes how a provider is constructed.
type Option func(*providerOptions)

// providerOptions accumulates the values set by Option functions.
type providerOptions struct {
	// httpClient, when non-nil, replaces the client that would otherwise be
	// created with the resolved timeout.
	httpClient *http.Client
	// timeoutOverride, when positive, lowers the resolved request timeout if
	// it is shorter than the configured one.
	timeoutOverride time.Duration
}

// providerSettings is the fully resolved configuration handed to the
// provider constructors.
type providerSettings struct {
	Name          string        // normalized provider name
	Model         string        // model name after defaulting
	BaseURL       string        // endpoint root without trailing slash
	APIKey        string        // resolved API key; may be empty
	MaxInputChars int           // per-input cap
	Timeout       time.Duration // resolved request timeout
	HTTPClient    *http.Client  // client used for all requests
}

// WithHTTPClient makes the provider use client for its requests.
func WithHTTPClient(client *http.Client) Option {
	apply := func(target *providerOptions) {
		target.httpClient = client
	}
	return apply
}

// WithRequestTimeout records timeout as a request-timeout override.
func WithRequestTimeout(timeout time.Duration) Option {
	apply := func(target *providerOptions) {
		target.timeoutOverride = timeout
	}
	return apply
}
|
||||
|
||||
func NewProvider(cfg Config, opts ...Option) (Provider, error) {
|
||||
settings, err := resolveProviderConfig(cfg, true, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newProvider(settings)
|
||||
}
|
||||
|
||||
func CheckProvider(ctx context.Context, cfg Config) CheckResult {
|
||||
settings, err := resolveProviderConfig(cfg, true, WithRequestTimeout(DefaultProbeTimeout))
|
||||
if err != nil {
|
||||
return CheckResult{
|
||||
Provider: normalizedProviderName(cfg.Provider),
|
||||
Model: strings.TrimSpace(cfg.Model),
|
||||
BaseURL: strings.TrimSpace(cfg.BaseURL),
|
||||
Status: "warning",
|
||||
Warning: err.Error(),
|
||||
}
|
||||
}
|
||||
result := CheckResult{
|
||||
Provider: settings.Name,
|
||||
Model: settings.Model,
|
||||
BaseURL: settings.BaseURL,
|
||||
Status: "ok",
|
||||
}
|
||||
if !shouldProbe(settings) {
|
||||
return result
|
||||
}
|
||||
provider, err := newProvider(settings)
|
||||
if err != nil {
|
||||
result.Status = "warning"
|
||||
result.Warning = err.Error()
|
||||
return result
|
||||
}
|
||||
probeCtx, cancel := context.WithTimeout(ctx, DefaultProbeTimeout)
|
||||
defer cancel()
|
||||
if _, err := provider.Embed(probeCtx, []string{"crawlkit probe"}); err != nil {
|
||||
result.Status = "warning"
|
||||
result.Warning = err.Error()
|
||||
return result
|
||||
}
|
||||
result.Probed = true
|
||||
return result
|
||||
}
|
||||
|
||||
func resolveProviderConfig(cfg Config, validateAPIKey bool, opts ...Option) (providerSettings, error) {
|
||||
options := providerOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
name := normalizedProviderName(cfg.Provider)
|
||||
if name == "" {
|
||||
name = ProviderOpenAI
|
||||
}
|
||||
model := strings.TrimSpace(cfg.Model)
|
||||
if model == "" {
|
||||
model = defaultModel(name)
|
||||
}
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
switch name {
|
||||
case ProviderOpenAI:
|
||||
baseURL = DefaultOpenAIBaseURL
|
||||
case ProviderOllama:
|
||||
baseURL = DefaultOllamaBaseURL
|
||||
case ProviderLlamaCpp:
|
||||
baseURL = DefaultLlamaCppBaseURL
|
||||
case ProviderOpenAICompatible:
|
||||
return providerSettings{}, fmt.Errorf("embedding provider %q requires base_url", name)
|
||||
}
|
||||
}
|
||||
timeout := DefaultRequestTimeout
|
||||
if strings.TrimSpace(cfg.RequestTimeout) != "" {
|
||||
parsed, err := time.ParseDuration(cfg.RequestTimeout)
|
||||
if err != nil {
|
||||
return providerSettings{}, fmt.Errorf("parse embeddings request_timeout: %w", err)
|
||||
}
|
||||
if parsed <= 0 {
|
||||
return providerSettings{}, errors.New("embeddings request_timeout must be positive")
|
||||
}
|
||||
timeout = parsed
|
||||
}
|
||||
if options.timeoutOverride > 0 && options.timeoutOverride < timeout {
|
||||
timeout = options.timeoutOverride
|
||||
}
|
||||
maxInputChars := cfg.MaxInputChars
|
||||
if maxInputChars <= 0 {
|
||||
maxInputChars = DefaultMaxInputChars
|
||||
}
|
||||
switch name {
|
||||
case ProviderOpenAI, ProviderOllama, ProviderLlamaCpp, ProviderOpenAICompatible:
|
||||
default:
|
||||
return providerSettings{}, fmt.Errorf("unsupported embedding provider %q", name)
|
||||
}
|
||||
apiKey, err := resolveAPIKey(name, cfg.APIKeyEnv, validateAPIKey)
|
||||
if err != nil {
|
||||
return providerSettings{}, err
|
||||
}
|
||||
client := options.httpClient
|
||||
if client == nil {
|
||||
client = &http.Client{Timeout: timeout}
|
||||
}
|
||||
if _, err := url.ParseRequestURI(baseURL); err != nil {
|
||||
return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: %w", baseURL, err)
|
||||
}
|
||||
return providerSettings{
|
||||
Name: name,
|
||||
Model: model,
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
MaxInputChars: maxInputChars,
|
||||
Timeout: timeout,
|
||||
HTTPClient: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newProvider(settings providerSettings) (Provider, error) {
|
||||
switch settings.Name {
|
||||
case ProviderOllama:
|
||||
return newOllamaProvider(settings), nil
|
||||
case ProviderOpenAI, ProviderLlamaCpp, ProviderOpenAICompatible:
|
||||
return newOpenAICompatibleProvider(settings), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported embedding provider %q", settings.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAPIKey(provider, apiKeyEnv string, validate bool) (string, error) {
|
||||
envName := strings.TrimSpace(apiKeyEnv)
|
||||
required := provider == ProviderOpenAI
|
||||
if envName == "" {
|
||||
if required {
|
||||
envName = "OPENAI_API_KEY"
|
||||
} else {
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
value := strings.TrimSpace(os.Getenv(envName))
|
||||
if value == "" {
|
||||
if required || validate {
|
||||
return "", fmt.Errorf("embedding provider %q requires API key env %s", provider, envName)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// normalizedProviderName strips surrounding whitespace from provider and
// lower-cases the result.
func normalizedProviderName(provider string) string {
	trimmed := strings.TrimSpace(provider)
	return strings.ToLower(trimmed)
}
|
||||
|
||||
func defaultModel(provider string) string {
|
||||
switch provider {
|
||||
case ProviderOllama, ProviderLlamaCpp:
|
||||
return DefaultLocalEmbeddingModel
|
||||
default:
|
||||
return DefaultOpenAIModel
|
||||
}
|
||||
}
|
||||
|
||||
func shouldProbe(settings providerSettings) bool {
|
||||
switch settings.Name {
|
||||
case ProviderOllama, ProviderLlamaCpp:
|
||||
return true
|
||||
case ProviderOpenAICompatible:
|
||||
return isLoopbackBaseURL(settings.BaseURL)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// isLoopbackBaseURL reports whether rawURL's host is the local machine:
// either the name "localhost" (compared case-insensitively, since host names
// are not case-sensitive and url.Parse preserves the original case) or an IP
// literal in a loopback range. Unparsable URLs and URLs without a host
// report false.
func isLoopbackBaseURL(rawURL string) bool {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return false
	}
	host := parsed.Hostname()
	if strings.EqualFold(host, "localhost") {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}
|
||||
|
||||
func trimInputs(inputs []string, maxChars int) []string {
|
||||
if maxChars <= 0 {
|
||||
maxChars = DefaultMaxInputChars
|
||||
}
|
||||
out := make([]string, len(inputs))
|
||||
for i, input := range inputs {
|
||||
runes := []rune(input)
|
||||
if len(runes) > maxChars {
|
||||
runes = runes[:maxChars]
|
||||
}
|
||||
out[i] = string(runes)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// inferDimensions returns the common length of vectors. It fails when any
// vector is empty or when a vector's length differs from the first one. An
// empty vectors slice yields 0 with no error.
func inferDimensions(vectors [][]float32) (int, error) {
	if len(vectors) == 0 {
		return 0, nil
	}

	want := len(vectors[0])
	for _, candidate := range vectors {
		switch got := len(candidate); {
		case got == 0:
			return 0, errors.New("embedding response contained an empty vector")
		case got != want:
			return 0, fmt.Errorf("embedding response dimensions mismatch: got %d want %d", got, want)
		}
	}
	return want, nil
}
|
||||
453
embed/provider_test.go
Normal file
453
embed/provider_test.go
Normal file
@ -0,0 +1,453 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOllamaProviderEmbeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
var req ollamaEmbedRequest
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, "nomic-embed-text", req.Model)
|
||||
assert.Equal(t, []string{"abcd", "xy"}, req.Input)
|
||||
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewProvider(Config{
|
||||
Provider: ProviderOllama,
|
||||
Model: "nomic-embed-text",
|
||||
BaseURL: server.URL,
|
||||
MaxInputChars: 4,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
batch, err := provider.Embed(context.Background(), []string{"abcdef", "xy"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nomic-embed-text", batch.Model)
|
||||
require.Equal(t, 3, batch.Dimensions)
|
||||
require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, batch.Vectors)
|
||||
}
|
||||
|
||||
// assertAPI and requireAPI are minimal stand-ins for the usual assertion
// helpers: assert methods record a failure and keep going, require methods
// stop the test immediately.
type (
	assertAPI  struct{}
	requireAPI struct{}
)

var (
	assert  assertAPI
	require requireAPI
)

// Equal records a failure unless want and got are deeply equal; it reports
// whether they matched.
func (assertAPI) Equal(t *testing.T, want, got any) bool {
	t.Helper()
	matched := reflect.DeepEqual(want, got)
	if !matched {
		t.Errorf("not equal:\nwant: %#v\n got: %#v", want, got)
	}
	return matched
}

// NoError records a failure when err is non-nil; it reports whether err was
// nil.
func (assertAPI) NoError(t *testing.T, err error) bool {
	t.Helper()
	if err == nil {
		return true
	}
	t.Errorf("unexpected error: %v", err)
	return false
}

// NoError aborts the test when err is non-nil.
func (requireAPI) NoError(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("unexpected error: %v", err)
}

// Equal aborts the test unless want and got are deeply equal.
func (requireAPI) Equal(t *testing.T, want, got any) {
	t.Helper()
	if reflect.DeepEqual(want, got) {
		return
	}
	t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got)
}

// Same aborts the test unless want and got are both valid values that share
// the same underlying pointer.
func (requireAPI) Same(t *testing.T, want, got any) {
	t.Helper()
	wantValue, gotValue := reflect.ValueOf(want), reflect.ValueOf(got)
	if !wantValue.IsValid() || !gotValue.IsValid() || wantValue.Pointer() != gotValue.Pointer() {
		t.Fatalf("not same:\nwant: %#v\n got: %#v", want, got)
	}
}

// True aborts the test when value is false.
func (requireAPI) True(t *testing.T, value bool) {
	t.Helper()
	if value {
		return
	}
	t.Fatal("expected true")
}

// False aborts the test when value is true.
func (requireAPI) False(t *testing.T, value bool) {
	t.Helper()
	if !value {
		return
	}
	t.Fatal("expected false")
}

// Empty aborts the test when value is not the empty string.
func (requireAPI) Empty(t *testing.T, value string) {
	t.Helper()
	if len(value) != 0 {
		t.Fatalf("expected empty string, got %q", value)
	}
}

// Contains aborts the test unless needle occurs in value.
func (requireAPI) Contains(t *testing.T, value, needle string) {
	t.Helper()
	if strings.Contains(value, needle) {
		return
	}
	t.Fatalf("expected %q to contain %q", value, needle)
}

// ErrorContains aborts the test unless err is non-nil and its message
// contains needle.
func (requireAPI) ErrorContains(t *testing.T, err error, needle string) {
	t.Helper()
	if err == nil {
		t.Fatalf("expected error containing %q, got nil", needle)
	}
	if message := err.Error(); !strings.Contains(message, needle) {
		t.Fatalf("expected error containing %q, got %q", needle, message)
	}
}
|
||||
|
||||
func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/embeddings", r.URL.Path)
|
||||
assert.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
|
||||
var req openAIEmbeddingRequest
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, "local-model", req.Model)
|
||||
assert.Equal(t, []string{"one", "two"}, req.Input)
|
||||
_, _ = w.Write([]byte(`{
|
||||
"model":"local-model",
|
||||
"data":[
|
||||
{"index":1,"embedding":[3,4]},
|
||||
{"index":0,"embedding":[1,2]}
|
||||
]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
t.Setenv("CRAWLKIT_EMBED_KEY", "secret")
|
||||
|
||||
provider, err := NewProvider(Config{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "local-model",
|
||||
BaseURL: server.URL,
|
||||
APIKeyEnv: "CRAWLKIT_EMBED_KEY",
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
batch, err := provider.Embed(context.Background(), []string{"one", "two"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "local-model", batch.Model)
|
||||
require.Equal(t, 2, batch.Dimensions)
|
||||
require.Equal(t, [][]float32{{1, 2}, {3, 4}}, batch.Vectors)
|
||||
}
|
||||
|
||||
func TestProviderFactoryDefaultsAndValidation(t *testing.T) {
|
||||
t.Setenv("OPENAI_API_KEY", "openai-secret")
|
||||
|
||||
openAI, err := resolveProviderConfig(Config{
|
||||
Provider: ProviderOpenAI,
|
||||
RequestTimeout: "5s",
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL)
|
||||
require.Equal(t, DefaultOpenAIModel, openAI.Model)
|
||||
require.Equal(t, "openai-secret", openAI.APIKey)
|
||||
|
||||
ollama, err := resolveProviderConfig(Config{
|
||||
Provider: ProviderOllama,
|
||||
RequestTimeout: "5s",
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL)
|
||||
require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model)
|
||||
|
||||
llamaCpp, err := resolveProviderConfig(Config{
|
||||
Provider: ProviderLlamaCpp,
|
||||
RequestTimeout: "5s",
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL)
|
||||
|
||||
_, err = resolveProviderConfig(Config{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
RequestTimeout: "5s",
|
||||
}, true)
|
||||
require.ErrorContains(t, err, "requires base_url")
|
||||
}
|
||||
|
||||
func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) {
|
||||
t.Setenv("OPENAI_API_KEY", "")
|
||||
|
||||
_, err := NewProvider(Config{
|
||||
Provider: ProviderOpenAI,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY")
|
||||
}
|
||||
|
||||
func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) {
|
||||
t.Setenv("MISSING_EMBED_KEY", "")
|
||||
|
||||
_, err := NewProvider(Config{
|
||||
Provider: "bogus",
|
||||
APIKeyEnv: "MISSING_EMBED_KEY",
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"")
|
||||
}
|
||||
|
||||
func TestCheckProviderProbesLocalProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
||||
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
result := CheckProvider(context.Background(), Config{
|
||||
Provider: ProviderOllama,
|
||||
Model: "nomic-embed-text",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.Equal(t, "ok", result.Status)
|
||||
require.True(t, result.Probed)
|
||||
require.Empty(t, result.Warning)
|
||||
require.Equal(t, server.URL, result.BaseURL)
|
||||
}
|
||||
|
||||
func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "not ready", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
result := CheckProvider(context.Background(), Config{
|
||||
Provider: ProviderOllama,
|
||||
Model: "nomic-embed-text",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.Equal(t, "warning", result.Status)
|
||||
require.Contains(t, result.Warning, "HTTP 503")
|
||||
require.False(t, result.Probed)
|
||||
}
|
||||
|
||||
func TestProviderExposesRateLimitErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewProvider(Config{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "local-model",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = provider.Embed(context.Background(), []string{"one"})
|
||||
require.ErrorContains(t, err, "HTTP 429")
|
||||
require.True(t, IsRateLimitError(err))
|
||||
}
|
||||
|
||||
func TestProviderRejectsInvalidResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewProvider(Config{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "local-model",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = provider.Embed(context.Background(), []string{"one", "two"})
|
||||
require.ErrorContains(t, err, "dimensions mismatch")
|
||||
}
|
||||
|
||||
func TestEmbeddingProvidersHandleEmptyInputsAndIndexErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
settings := providerSettings{
|
||||
Name: ProviderOllama,
|
||||
Model: "model",
|
||||
BaseURL: "http://127.0.0.1:1",
|
||||
MaxInputChars: 10,
|
||||
HTTPClient: http.DefaultClient,
|
||||
}
|
||||
ollama := newOllamaProvider(settings)
|
||||
batch, err := ollama.Embed(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "model", batch.Model)
|
||||
|
||||
settings.Name = ProviderOpenAICompatible
|
||||
openai := newOpenAICompatibleProvider(settings)
|
||||
batch, err = openai.Embed(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "model", batch.Model)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
inputs []string
|
||||
want string
|
||||
}{
|
||||
{name: "count", body: `{"data":[]}`, inputs: []string{"one"}, want: "returned 0 vectors for 1 inputs"},
|
||||
{name: "range", body: `{"data":[{"index":2,"embedding":[1]}]}`, inputs: []string{"one"}, want: "index 2 out of range"},
|
||||
{name: "duplicate", body: `{"data":[{"index":0,"embedding":[1]},{"index":0,"embedding":[2]}]}`, inputs: []string{"one", "two"}, want: "duplicated index 0"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(tc.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
provider, err := NewProvider(Config{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "model",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = provider.Embed(context.Background(), tc.inputs)
|
||||
require.ErrorContains(t, err, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderOptionsAndProbeDecisions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &http.Client{Timeout: time.Second}
|
||||
settings, err := resolveProviderConfig(Config{
|
||||
Provider: ProviderOllama,
|
||||
BaseURL: "http://127.0.0.1:11434/",
|
||||
RequestTimeout: "30s",
|
||||
}, true, WithHTTPClient(client), WithRequestTimeout(50*time.Millisecond))
|
||||
require.NoError(t, err)
|
||||
require.Same(t, client, settings.HTTPClient)
|
||||
require.Equal(t, 50*time.Millisecond, settings.Timeout)
|
||||
require.Equal(t, "http://127.0.0.1:11434", settings.BaseURL)
|
||||
require.True(t, shouldProbe(settings))
|
||||
|
||||
require.True(t, isLoopbackBaseURL("http://localhost:8080/v1"))
|
||||
require.True(t, isLoopbackBaseURL("http://[::1]:8080/v1"))
|
||||
require.False(t, isLoopbackBaseURL("https://api.example.com/v1"))
|
||||
require.False(t, isLoopbackBaseURL("://bad"))
|
||||
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAI}))
|
||||
require.True(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "http://localhost:8080/v1"}))
|
||||
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "https://api.example.com/v1"}))
|
||||
}
|
||||
|
||||
func TestProviderValidationEdges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := resolveProviderConfig(Config{
|
||||
Provider: ProviderOllama,
|
||||
RequestTimeout: "not-a-duration",
|
||||
}, true)
|
||||
require.ErrorContains(t, err, "parse embeddings request_timeout")
|
||||
|
||||
_, err = resolveProviderConfig(Config{
|
||||
Provider: ProviderOllama,
|
||||
RequestTimeout: "0s",
|
||||
}, true)
|
||||
require.ErrorContains(t, err, "must be positive")
|
||||
|
||||
_, err = resolveProviderConfig(Config{
|
||||
Provider: ProviderOllama,
|
||||
BaseURL: "://bad",
|
||||
}, true)
|
||||
require.ErrorContains(t, err, "invalid embeddings base_url")
|
||||
|
||||
key, err := resolveAPIKey(ProviderOpenAICompatible, "MISSING_EMBED_KEY", false)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, key)
|
||||
|
||||
_, err = newProvider(providerSettings{Name: "bogus"})
|
||||
require.ErrorContains(t, err, "unsupported embedding provider")
|
||||
|
||||
require.Equal(t, []string{"abc"}, trimInputs([]string{"abc"}, 0))
|
||||
_, err = inferDimensions([][]float32{{}})
|
||||
require.ErrorContains(t, err, "empty vector")
|
||||
}
|
||||
|
||||
func TestOllamaProviderResponseEdges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
countServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
||||
_, _ = w.Write([]byte(`{"embeddings":[]}`))
|
||||
}))
|
||||
defer countServer.Close()
|
||||
|
||||
provider := newOllamaProvider(providerSettings{
|
||||
HTTPClient: countServer.Client(),
|
||||
BaseURL: countServer.URL,
|
||||
Model: "fallback-model",
|
||||
MaxInputChars: 10,
|
||||
})
|
||||
_, err := provider.Embed(context.Background(), []string{"one"})
|
||||
require.ErrorContains(t, err, "returned 0 vectors for 1 inputs")
|
||||
|
||||
modelServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
||||
_, _ = w.Write([]byte(`{"embeddings":[[1,2]]}`))
|
||||
}))
|
||||
defer modelServer.Close()
|
||||
|
||||
provider = newOllamaProvider(providerSettings{
|
||||
HTTPClient: modelServer.Client(),
|
||||
BaseURL: modelServer.URL,
|
||||
Model: "fallback-model",
|
||||
MaxInputChars: 10,
|
||||
})
|
||||
batch, err := provider.Embed(context.Background(), []string{"one"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-model", batch.Model)
|
||||
}
|
||||
|
||||
func TestCheckProviderSkipsRemoteCompatibleProbe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := CheckProvider(context.Background(), Config{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "remote-model",
|
||||
BaseURL: "https://api.example.com/v1",
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.Equal(t, "ok", result.Status)
|
||||
require.False(t, result.Probed)
|
||||
require.Empty(t, result.Warning)
|
||||
}
|
||||
142
vector/vector.go
Normal file
142
vector/vector.go
Normal file
@ -0,0 +1,142 @@
|
||||
// Package vector provides helpers for float32 vector blobs, dimension
// validation, cosine similarity, top-k selection, and reciprocal-rank fusion.
package vector
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// DefaultRRFK is the rank constant ReciprocalRankFusion falls back to when the
// caller passes a non-positive k.
const DefaultRRFK = 60.0
|
||||
|
||||
// Scored pairs an item with the score it is ranked by.
type Scored[T any] struct {
	// Item is the value being ranked.
	Item T
	// Score is the ranking score; TopK orders higher scores first.
	Score float64
}
|
||||
|
||||
// RRFEntry is one fused result produced by ReciprocalRankFusion.
type RRFEntry[T any] struct {
	// Item is the value taken from the first ranking in which its id appeared.
	Item T
	// Score is the accumulated reciprocal-rank score across all rankings.
	Score float64
}
|
||||
|
||||
// EncodeFloat32 serializes values as consecutive little-endian IEEE-754
// float32 words (4 bytes per element).
//
// The result is always a non-nil slice of length len(values)*4, including for
// empty input. The error is always nil; it is kept in the signature for
// backward compatibility with existing callers.
func EncodeFloat32(values []float32) ([]byte, error) {
	buf := bytes.NewBuffer(make([]byte, 0, len(values)*4))
	// Encode each word into a reusable scratch array instead of calling
	// binary.Write per element, which boxes the value and allocates each time.
	var word [4]byte
	for _, value := range values {
		binary.LittleEndian.PutUint32(word[:], math.Float32bits(value))
		buf.Write(word[:]) // bytes.Buffer.Write always returns a nil error
	}
	return buf.Bytes(), nil
}
|
||||
|
||||
// DecodeFloat32 parses blob as consecutive little-endian IEEE-754 float32
// words and returns the decoded values.
//
// It returns an error when len(blob) is not a multiple of 4. An empty or nil
// blob decodes to a non-nil, zero-length slice.
func DecodeFloat32(blob []byte) ([]float32, error) {
	if len(blob)%4 != 0 {
		return nil, fmt.Errorf("float32 vector blob length %d is not a multiple of 4", len(blob))
	}
	out := make([]float32, len(blob)/4)
	// The length check above guarantees every 4-byte window is in range, so
	// decode straight from the slice rather than through a reader and
	// binary.Read for every element.
	for i := range out {
		out[i] = math.Float32frombits(binary.LittleEndian.Uint32(blob[i*4:]))
	}
	return out, nil
}
|
||||
|
||||
// ValidateDimensions reports whether values has exactly dimensions elements.
//
// It returns an error when dimensions is not positive, or when len(values)
// differs from dimensions; otherwise it returns nil.
func ValidateDimensions(values []float32, dimensions int) error {
	switch got := len(values); {
	case dimensions <= 0:
		return errors.New("dimensions must be positive")
	case got != dimensions:
		return fmt.Errorf("dimensions mismatch: got %d want %d", got, dimensions)
	default:
		return nil
	}
}
|
||||
|
||||
// Norm returns the Euclidean (L2) length of values, accumulating in float64.
// An empty or nil slice has norm 0.
func Norm(values []float32) float64 {
	squares := 0.0
	for i := range values {
		component := float64(values[i])
		squares += component * component
	}
	return math.Sqrt(squares)
}
|
||||
|
||||
func CosineSimilarity(query []float32, queryNorm float64, candidate []float32) (float64, error) {
|
||||
if len(candidate) != len(query) {
|
||||
return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(candidate), len(query))
|
||||
}
|
||||
if queryNorm == 0 {
|
||||
return 0, errors.New("query vector is zero")
|
||||
}
|
||||
candidateNorm := Norm(candidate)
|
||||
if candidateNorm == 0 {
|
||||
return 0, errors.New("candidate vector is zero")
|
||||
}
|
||||
var dot float64
|
||||
for i := range query {
|
||||
dot += float64(query[i]) * float64(candidate[i])
|
||||
}
|
||||
return dot / (queryNorm * candidateNorm), nil
|
||||
}
|
||||
|
||||
func TopK[T any](items []Scored[T], limit int, tieLess func(left, right T) bool) []Scored[T] {
|
||||
if limit <= 0 || len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
sorted := append([]Scored[T](nil), items...)
|
||||
sort.SliceStable(sorted, func(i, j int) bool {
|
||||
if sorted[i].Score != sorted[j].Score {
|
||||
return sorted[i].Score > sorted[j].Score
|
||||
}
|
||||
if tieLess == nil {
|
||||
return false
|
||||
}
|
||||
return tieLess(sorted[i].Item, sorted[j].Item)
|
||||
})
|
||||
if len(sorted) > limit {
|
||||
sorted = sorted[:limit]
|
||||
}
|
||||
return sorted
|
||||
}
|
||||
|
||||
func ReciprocalRankFusion[T any](rankings [][]T, ids []func(T) string, weights []float64, k float64) []RRFEntry[T] {
|
||||
if k <= 0 {
|
||||
k = DefaultRRFK
|
||||
}
|
||||
entries := map[string]*RRFEntry[T]{}
|
||||
for rankingIndex, ranking := range rankings {
|
||||
weight := 1.0
|
||||
if rankingIndex < len(weights) && weights[rankingIndex] != 0 {
|
||||
weight = weights[rankingIndex]
|
||||
}
|
||||
var idFn func(T) string
|
||||
if rankingIndex < len(ids) {
|
||||
idFn = ids[rankingIndex]
|
||||
}
|
||||
for index, item := range ranking {
|
||||
if idFn == nil {
|
||||
continue
|
||||
}
|
||||
id := idFn(item)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
entry := entries[id]
|
||||
if entry == nil {
|
||||
entry = &RRFEntry[T]{Item: item}
|
||||
entries[id] = entry
|
||||
}
|
||||
entry.Score += weight / (k + float64(index+1))
|
||||
}
|
||||
}
|
||||
out := make([]RRFEntry[T], 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
out = append(out, *entry)
|
||||
}
|
||||
sort.SliceStable(out, func(i, j int) bool {
|
||||
return out[i].Score > out[j].Score
|
||||
})
|
||||
return out
|
||||
}
|
||||
137
vector/vector_test.go
Normal file
137
vector/vector_test.go
Normal file
@ -0,0 +1,137 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFloat32EncodingRoundTrip(t *testing.T) {
|
||||
blob, err := EncodeFloat32([]float32{1, -2.5, 3.25})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, blob, 12)
|
||||
|
||||
values, err := DecodeFloat32(blob)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []float32{1, -2.5, 3.25}, values)
|
||||
|
||||
_, err = DecodeFloat32([]byte{1, 2, 3})
|
||||
require.ErrorContains(t, err, "not a multiple of 4")
|
||||
}
|
||||
|
||||
func TestCosineSimilarityAndDimensions(t *testing.T) {
|
||||
require.NoError(t, ValidateDimensions([]float32{1, 2}, 2))
|
||||
require.ErrorContains(t, ValidateDimensions([]float32{1}, 2), "dimensions mismatch")
|
||||
require.ErrorContains(t, ValidateDimensions([]float32{1}, 0), "positive")
|
||||
|
||||
query := []float32{1, 0}
|
||||
score, err := CosineSimilarity(query, Norm(query), []float32{0.5, 0})
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1, score, 0.0001)
|
||||
|
||||
_, err = CosineSimilarity(query, 0, []float32{1, 0})
|
||||
require.ErrorContains(t, err, "query vector is zero")
|
||||
_, err = CosineSimilarity(query, Norm(query), []float32{0, 0})
|
||||
require.ErrorContains(t, err, "candidate vector is zero")
|
||||
_, err = CosineSimilarity(query, Norm(query), []float32{1})
|
||||
require.ErrorContains(t, err, "dimensions mismatch")
|
||||
require.Equal(t, math.Sqrt(5), Norm([]float32{1, 2}))
|
||||
}
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
items := []Scored[string]{
|
||||
{Item: "c", Score: 0.3},
|
||||
{Item: "a", Score: 0.5},
|
||||
{Item: "b", Score: 0.5},
|
||||
}
|
||||
top := TopK(items, 2, func(left, right string) bool { return left < right })
|
||||
require.Equal(t, []Scored[string]{{Item: "a", Score: 0.5}, {Item: "b", Score: 0.5}}, top)
|
||||
require.Nil(t, TopK(items, 0, nil))
|
||||
}
|
||||
|
||||
func TestReciprocalRankFusion(t *testing.T) {
|
||||
rankings := [][]string{
|
||||
{"a", "b"},
|
||||
{"b", "c"},
|
||||
}
|
||||
ids := []func(string) string{
|
||||
func(value string) string { return value },
|
||||
func(value string) string { return value },
|
||||
}
|
||||
results := ReciprocalRankFusion(rankings, ids, []float64{1, 1}, 60)
|
||||
require.Len(t, results, 3)
|
||||
require.Equal(t, "b", results[0].Item)
|
||||
require.Greater(t, results[0].Score, results[1].Score)
|
||||
}
|
||||
|
||||
// requireAPI is a minimal, dependency-free stand-in for the assertion helpers
// these tests need. Every method fails the test immediately via t.Fatalf when
// its condition does not hold.
type requireAPI struct{}

// require is the shared instance the tests call, e.g. require.NoError(t, err).
var require requireAPI

// NoError fails the test when err is non-nil.
func (requireAPI) NoError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}

// Equal fails the test unless want and got are deeply equal.
func (requireAPI) Equal(t *testing.T, want, got any) {
	t.Helper()
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got)
	}
}

// Len fails the test unless value has length want. value must be of a kind
// that reflect.Value.Len accepts (for example a slice, map, or string).
func (requireAPI) Len(t *testing.T, value any, want int) {
	t.Helper()
	got := reflect.ValueOf(value).Len()
	if got != want {
		t.Fatalf("len mismatch: got %d want %d", got, want)
	}
}

// Nil fails the test unless value is nil, including typed nil pointers,
// slices, maps, channels, functions, and interfaces.
func (requireAPI) Nil(t *testing.T, value any) {
	t.Helper()
	if !isNil(value) {
		t.Fatalf("expected nil, got %#v", value)
	}
}

// Greater fails the test unless left is strictly greater than right.
func (requireAPI) Greater(t *testing.T, left, right float64) {
	t.Helper()
	if left <= right {
		t.Fatalf("expected %v > %v", left, right)
	}
}

// InDelta fails the test unless got lies within delta of want. A NaN
// difference (for example when got is NaN) is treated as a failure.
func (requireAPI) InDelta(t *testing.T, want, got, delta float64) {
	t.Helper()
	if !withinDelta(want, got, delta) {
		t.Fatalf("not within delta: want %v got %v delta %v", want, got, delta)
	}
}

// ErrorContains fails the test unless err is non-nil and its message contains
// needle.
func (requireAPI) ErrorContains(t *testing.T, err error, needle string) {
	t.Helper()
	if err == nil {
		t.Fatalf("expected error containing %q, got nil", needle)
	}
	if !strings.Contains(err.Error(), needle) {
		t.Fatalf("expected error containing %q, got %q", needle, err.Error())
	}
}

// withinDelta reports whether |want-got| <= delta. NaN is rejected explicitly:
// every ordered comparison against NaN is false, so a plain "diff > delta"
// check would let a NaN result pass unnoticed.
func withinDelta(want, got, delta float64) bool {
	diff := math.Abs(want - got)
	if math.IsNaN(diff) {
		return false
	}
	return diff <= delta
}

// isNil reports whether value is an untyped nil or a nil value of a nilable
// kind. Non-nilable kinds are never nil.
func isNil(value any) bool {
	if value == nil {
		return true
	}
	reflected := reflect.ValueOf(value)
	switch reflected.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return reflected.IsNil()
	default:
		return false
	}
}
|
||||
Loading…
Reference in New Issue
Block a user