454 lines
13 KiB
Go
454 lines
13 KiB
Go
package embed
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestOllamaProviderEmbeds(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "/api/embed", r.URL.Path)
|
|
assert.Equal(t, http.MethodPost, r.Method)
|
|
var req ollamaEmbedRequest
|
|
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
|
assert.Equal(t, "nomic-embed-text", req.Model)
|
|
assert.Equal(t, []string{"abcd", "xy"}, req.Input)
|
|
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
provider, err := NewProvider(Config{
|
|
Provider: ProviderOllama,
|
|
Model: "nomic-embed-text",
|
|
BaseURL: server.URL,
|
|
MaxInputChars: 4,
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
batch, err := provider.Embed(context.Background(), []string{"abcdef", "xy"})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "nomic-embed-text", batch.Model)
|
|
require.Equal(t, 3, batch.Dimensions)
|
|
require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, batch.Vectors)
|
|
}
|
|
|
|
// assertAPI and requireAPI are small, dependency-free helpers. Their method
// names and (want, got) argument order follow testify's assert/require
// packages. The assertAPI methods report failures with t.Errorf and let the
// test continue. The requireAPI methods stop the test with t.Fatal/t.Fatalf.
type (
	assertAPI  struct{}
	requireAPI struct{}
)

// assert and require are the package-level handles the tests call through,
// e.g. require.NoError(t, err).
var (
	assert  assertAPI
	require requireAPI
)
|
|
|
|
func (assertAPI) Equal(t *testing.T, want, got any) bool {
|
|
t.Helper()
|
|
if !reflect.DeepEqual(want, got) {
|
|
t.Errorf("not equal:\nwant: %#v\n got: %#v", want, got)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (assertAPI) NoError(t *testing.T, err error) bool {
|
|
t.Helper()
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (requireAPI) NoError(t *testing.T, err error) {
|
|
t.Helper()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func (requireAPI) Equal(t *testing.T, want, got any) {
|
|
t.Helper()
|
|
if !reflect.DeepEqual(want, got) {
|
|
t.Fatalf("not equal:\nwant: %#v\n got: %#v", want, got)
|
|
}
|
|
}
|
|
|
|
func (requireAPI) Same(t *testing.T, want, got any) {
|
|
t.Helper()
|
|
if !reflect.ValueOf(want).IsValid() || !reflect.ValueOf(got).IsValid() ||
|
|
reflect.ValueOf(want).Pointer() != reflect.ValueOf(got).Pointer() {
|
|
t.Fatalf("not same:\nwant: %#v\n got: %#v", want, got)
|
|
}
|
|
}
|
|
|
|
func (requireAPI) True(t *testing.T, value bool) {
|
|
t.Helper()
|
|
if !value {
|
|
t.Fatal("expected true")
|
|
}
|
|
}
|
|
|
|
func (requireAPI) False(t *testing.T, value bool) {
|
|
t.Helper()
|
|
if value {
|
|
t.Fatal("expected false")
|
|
}
|
|
}
|
|
|
|
func (requireAPI) Empty(t *testing.T, value string) {
|
|
t.Helper()
|
|
if value != "" {
|
|
t.Fatalf("expected empty string, got %q", value)
|
|
}
|
|
}
|
|
|
|
func (requireAPI) Contains(t *testing.T, value, needle string) {
|
|
t.Helper()
|
|
if !strings.Contains(value, needle) {
|
|
t.Fatalf("expected %q to contain %q", value, needle)
|
|
}
|
|
}
|
|
|
|
func (requireAPI) ErrorContains(t *testing.T, err error, needle string) {
|
|
t.Helper()
|
|
if err == nil {
|
|
t.Fatalf("expected error containing %q, got nil", needle)
|
|
}
|
|
if !strings.Contains(err.Error(), needle) {
|
|
t.Fatalf("expected error containing %q, got %q", needle, err.Error())
|
|
}
|
|
}
|
|
|
|
func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "/embeddings", r.URL.Path)
|
|
assert.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
|
|
var req openAIEmbeddingRequest
|
|
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
|
assert.Equal(t, "local-model", req.Model)
|
|
assert.Equal(t, []string{"one", "two"}, req.Input)
|
|
_, _ = w.Write([]byte(`{
|
|
"model":"local-model",
|
|
"data":[
|
|
{"index":1,"embedding":[3,4]},
|
|
{"index":0,"embedding":[1,2]}
|
|
]
|
|
}`))
|
|
}))
|
|
defer server.Close()
|
|
t.Setenv("CRAWLKIT_EMBED_KEY", "secret")
|
|
|
|
provider, err := NewProvider(Config{
|
|
Provider: ProviderOpenAICompatible,
|
|
Model: "local-model",
|
|
BaseURL: server.URL,
|
|
APIKeyEnv: "CRAWLKIT_EMBED_KEY",
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
batch, err := provider.Embed(context.Background(), []string{"one", "two"})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "local-model", batch.Model)
|
|
require.Equal(t, 2, batch.Dimensions)
|
|
require.Equal(t, [][]float32{{1, 2}, {3, 4}}, batch.Vectors)
|
|
}
|
|
|
|
func TestProviderFactoryDefaultsAndValidation(t *testing.T) {
|
|
t.Setenv("OPENAI_API_KEY", "openai-secret")
|
|
|
|
openAI, err := resolveProviderConfig(Config{
|
|
Provider: ProviderOpenAI,
|
|
RequestTimeout: "5s",
|
|
}, true)
|
|
require.NoError(t, err)
|
|
require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL)
|
|
require.Equal(t, DefaultOpenAIModel, openAI.Model)
|
|
require.Equal(t, "openai-secret", openAI.APIKey)
|
|
|
|
ollama, err := resolveProviderConfig(Config{
|
|
Provider: ProviderOllama,
|
|
RequestTimeout: "5s",
|
|
}, true)
|
|
require.NoError(t, err)
|
|
require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL)
|
|
require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model)
|
|
|
|
llamaCpp, err := resolveProviderConfig(Config{
|
|
Provider: ProviderLlamaCpp,
|
|
RequestTimeout: "5s",
|
|
}, true)
|
|
require.NoError(t, err)
|
|
require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL)
|
|
|
|
_, err = resolveProviderConfig(Config{
|
|
Provider: ProviderOpenAICompatible,
|
|
RequestTimeout: "5s",
|
|
}, true)
|
|
require.ErrorContains(t, err, "requires base_url")
|
|
}
|
|
|
|
func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) {
|
|
t.Setenv("OPENAI_API_KEY", "")
|
|
|
|
_, err := NewProvider(Config{
|
|
Provider: ProviderOpenAI,
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY")
|
|
}
|
|
|
|
func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) {
|
|
t.Setenv("MISSING_EMBED_KEY", "")
|
|
|
|
_, err := NewProvider(Config{
|
|
Provider: "bogus",
|
|
APIKeyEnv: "MISSING_EMBED_KEY",
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"")
|
|
}
|
|
|
|
func TestCheckProviderProbesLocalProvider(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "/api/embed", r.URL.Path)
|
|
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
result := CheckProvider(context.Background(), Config{
|
|
Provider: ProviderOllama,
|
|
Model: "nomic-embed-text",
|
|
BaseURL: server.URL,
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.Equal(t, "ok", result.Status)
|
|
require.True(t, result.Probed)
|
|
require.Empty(t, result.Warning)
|
|
require.Equal(t, server.URL, result.BaseURL)
|
|
}
|
|
|
|
func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Error(w, "not ready", http.StatusServiceUnavailable)
|
|
}))
|
|
defer server.Close()
|
|
|
|
result := CheckProvider(context.Background(), Config{
|
|
Provider: ProviderOllama,
|
|
Model: "nomic-embed-text",
|
|
BaseURL: server.URL,
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.Equal(t, "warning", result.Status)
|
|
require.Contains(t, result.Warning, "HTTP 503")
|
|
require.False(t, result.Probed)
|
|
}
|
|
|
|
func TestProviderExposesRateLimitErrors(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
|
}))
|
|
defer server.Close()
|
|
|
|
provider, err := NewProvider(Config{
|
|
Provider: ProviderOpenAICompatible,
|
|
Model: "local-model",
|
|
BaseURL: server.URL,
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = provider.Embed(context.Background(), []string{"one"})
|
|
require.ErrorContains(t, err, "HTTP 429")
|
|
require.True(t, IsRateLimitError(err))
|
|
}
|
|
|
|
func TestProviderRejectsInvalidResponses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
provider, err := NewProvider(Config{
|
|
Provider: ProviderOpenAICompatible,
|
|
Model: "local-model",
|
|
BaseURL: server.URL,
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = provider.Embed(context.Background(), []string{"one", "two"})
|
|
require.ErrorContains(t, err, "dimensions mismatch")
|
|
}
|
|
|
|
func TestEmbeddingProvidersHandleEmptyInputsAndIndexErrors(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
settings := providerSettings{
|
|
Name: ProviderOllama,
|
|
Model: "model",
|
|
BaseURL: "http://127.0.0.1:1",
|
|
MaxInputChars: 10,
|
|
HTTPClient: http.DefaultClient,
|
|
}
|
|
ollama := newOllamaProvider(settings)
|
|
batch, err := ollama.Embed(context.Background(), nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "model", batch.Model)
|
|
|
|
settings.Name = ProviderOpenAICompatible
|
|
openai := newOpenAICompatibleProvider(settings)
|
|
batch, err = openai.Embed(context.Background(), nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "model", batch.Model)
|
|
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
inputs []string
|
|
want string
|
|
}{
|
|
{name: "count", body: `{"data":[]}`, inputs: []string{"one"}, want: "returned 0 vectors for 1 inputs"},
|
|
{name: "range", body: `{"data":[{"index":2,"embedding":[1]}]}`, inputs: []string{"one"}, want: "index 2 out of range"},
|
|
{name: "duplicate", body: `{"data":[{"index":0,"embedding":[1]},{"index":0,"embedding":[2]}]}`, inputs: []string{"one", "two"}, want: "duplicated index 0"},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(tc.body))
|
|
}))
|
|
defer server.Close()
|
|
provider, err := NewProvider(Config{
|
|
Provider: ProviderOpenAICompatible,
|
|
Model: "model",
|
|
BaseURL: server.URL,
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = provider.Embed(context.Background(), tc.inputs)
|
|
require.ErrorContains(t, err, tc.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProviderOptionsAndProbeDecisions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
client := &http.Client{Timeout: time.Second}
|
|
settings, err := resolveProviderConfig(Config{
|
|
Provider: ProviderOllama,
|
|
BaseURL: "http://127.0.0.1:11434/",
|
|
RequestTimeout: "30s",
|
|
}, true, WithHTTPClient(client), WithRequestTimeout(50*time.Millisecond))
|
|
require.NoError(t, err)
|
|
require.Same(t, client, settings.HTTPClient)
|
|
require.Equal(t, 50*time.Millisecond, settings.Timeout)
|
|
require.Equal(t, "http://127.0.0.1:11434", settings.BaseURL)
|
|
require.True(t, shouldProbe(settings))
|
|
|
|
require.True(t, isLoopbackBaseURL("http://localhost:8080/v1"))
|
|
require.True(t, isLoopbackBaseURL("http://[::1]:8080/v1"))
|
|
require.False(t, isLoopbackBaseURL("https://api.example.com/v1"))
|
|
require.False(t, isLoopbackBaseURL("://bad"))
|
|
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAI}))
|
|
require.True(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "http://localhost:8080/v1"}))
|
|
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "https://api.example.com/v1"}))
|
|
}
|
|
|
|
func TestProviderValidationEdges(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, err := resolveProviderConfig(Config{
|
|
Provider: ProviderOllama,
|
|
RequestTimeout: "not-a-duration",
|
|
}, true)
|
|
require.ErrorContains(t, err, "parse embeddings request_timeout")
|
|
|
|
_, err = resolveProviderConfig(Config{
|
|
Provider: ProviderOllama,
|
|
RequestTimeout: "0s",
|
|
}, true)
|
|
require.ErrorContains(t, err, "must be positive")
|
|
|
|
_, err = resolveProviderConfig(Config{
|
|
Provider: ProviderOllama,
|
|
BaseURL: "://bad",
|
|
}, true)
|
|
require.ErrorContains(t, err, "invalid embeddings base_url")
|
|
|
|
key, err := resolveAPIKey(ProviderOpenAICompatible, "MISSING_EMBED_KEY", false)
|
|
require.NoError(t, err)
|
|
require.Empty(t, key)
|
|
|
|
_, err = newProvider(providerSettings{Name: "bogus"})
|
|
require.ErrorContains(t, err, "unsupported embedding provider")
|
|
|
|
require.Equal(t, []string{"abc"}, trimInputs([]string{"abc"}, 0))
|
|
_, err = inferDimensions([][]float32{{}})
|
|
require.ErrorContains(t, err, "empty vector")
|
|
}
|
|
|
|
func TestOllamaProviderResponseEdges(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
countServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "/api/embed", r.URL.Path)
|
|
_, _ = w.Write([]byte(`{"embeddings":[]}`))
|
|
}))
|
|
defer countServer.Close()
|
|
|
|
provider := newOllamaProvider(providerSettings{
|
|
HTTPClient: countServer.Client(),
|
|
BaseURL: countServer.URL,
|
|
Model: "fallback-model",
|
|
MaxInputChars: 10,
|
|
})
|
|
_, err := provider.Embed(context.Background(), []string{"one"})
|
|
require.ErrorContains(t, err, "returned 0 vectors for 1 inputs")
|
|
|
|
modelServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "/api/embed", r.URL.Path)
|
|
_, _ = w.Write([]byte(`{"embeddings":[[1,2]]}`))
|
|
}))
|
|
defer modelServer.Close()
|
|
|
|
provider = newOllamaProvider(providerSettings{
|
|
HTTPClient: modelServer.Client(),
|
|
BaseURL: modelServer.URL,
|
|
Model: "fallback-model",
|
|
MaxInputChars: 10,
|
|
})
|
|
batch, err := provider.Embed(context.Background(), []string{"one"})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "fallback-model", batch.Model)
|
|
}
|
|
|
|
func TestCheckProviderSkipsRemoteCompatibleProbe(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := CheckProvider(context.Background(), Config{
|
|
Provider: ProviderOpenAICompatible,
|
|
Model: "remote-model",
|
|
BaseURL: "https://api.example.com/v1",
|
|
RequestTimeout: "5s",
|
|
})
|
|
require.Equal(t, "ok", result.Status)
|
|
require.False(t, result.Probed)
|
|
require.Empty(t, result.Warning)
|
|
}
|