feat(embed): add local embedding providers (#32)

This commit is contained in:
MrBrain 2026-04-22 13:06:39 +08:00 committed by GitHub
parent aa74be7b79
commit 2f07416702
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 889 additions and 12 deletions

View File

@ -7,6 +7,7 @@ All notable changes to `discrawl` will be documented in this file.
- Git-backed snapshot imports are now much faster on large archives by using import-only SQLite pragmas and bulk-load FTS5 settings during search index rebuilds
- `messages` and `mentions` now use composite read-path indexes so larger archives spend less time sorting/filtering common guild, channel, and author queries
- normalized message text is now sanitized before it reaches SQLite and FTS5, repairing malformed UTF-8 and stripping invisible/control-character noise that can poison search content
- local embedding providers now support OpenAI-compatible endpoints, Ollama, and llama.cpp, and `doctor` can probe the configured provider before you queue vectors
## 0.3.0 - 2026-04-21

View File

@ -12,6 +12,7 @@ import (
"github.com/steipete/discrawl/internal/config"
"github.com/steipete/discrawl/internal/discord"
"github.com/steipete/discrawl/internal/embed"
"github.com/steipete/discrawl/internal/store"
"github.com/steipete/discrawl/internal/syncer"
)
@ -187,6 +188,21 @@ func (r *runtime) runDoctor(args []string) error {
report["share_auto_update"] = cfg.Share.AutoUpdate
report["share_stale_after"] = cfg.Share.StaleAfter
}
if cfg.Search.Embeddings.Enabled {
check := embed.CheckProvider(r.ctx, cfg.Search.Embeddings)
report["embeddings"] = check.Status
report["embeddings_provider"] = check.Provider
report["embeddings_model"] = check.Model
report["embeddings_base_url"] = check.BaseURL
if check.Probed {
report["embeddings_probe"] = "ok"
}
if check.Warning != "" {
report["embeddings_warning"] = check.Warning
}
} else {
report["embeddings"] = "disabled"
}
token, err := config.ResolveDiscordToken(cfg)
if err != nil {
if cfg.Discord.TokenSource == "none" && cfg.ShareEnabled() {

View File

@ -4,6 +4,8 @@ import (
"bytes"
"context"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
@ -671,6 +673,93 @@ func TestRuntimeInitSyncTailAndDoctor(t *testing.T) {
require.Contains(t, out.String(), "discord_auth=ok")
}
// TestDoctorChecksEnabledLocalEmbeddingProvider runs `doctor` against a config
// whose embeddings are enabled for Ollama and pointed at a stub server. The
// report must show an ok status, the provider name, and a successful probe.
func TestDoctorChecksEnabledLocalEmbeddingProvider(t *testing.T) {
	// Stub Ollama endpoint: checks the request path and returns one vector.
	stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, "/api/embed", r.URL.Path)
		_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3]]}`))
	}))
	defer stub.Close()

	workDir := t.TempDir()
	configFile := filepath.Join(workDir, "config.toml")

	settings := config.Default()
	settings.DBPath = filepath.Join(workDir, "discrawl.db")
	settings.Search.Embeddings.Enabled = true
	settings.Search.Embeddings.Provider = "ollama"
	settings.Search.Embeddings.Model = "nomic-embed-text"
	settings.Search.Embeddings.BaseURL = stub.URL
	require.NoError(t, config.Write(configFile, settings))

	var report bytes.Buffer
	rt := &runtime{
		ctx:        context.Background(),
		configPath: configFile,
		stdout:     &report,
		stderr:     &bytes.Buffer{},
		logger:     discardLogger(),
	}
	require.NoError(t, rt.runDoctor(nil))

	printed := report.String()
	require.Contains(t, printed, "embeddings=ok")
	require.Contains(t, printed, "embeddings_provider=ollama")
	require.Contains(t, printed, "embeddings_probe=ok")
}
// TestDoctorReportsEmbeddingProviderWarningsNonFatally enables the OpenAI
// provider with an empty OPENAI_API_KEY and expects `doctor` to succeed while
// reporting the missing key as an embeddings warning.
func TestDoctorReportsEmbeddingProviderWarningsNonFatally(t *testing.T) {
	// Make sure no real key leaks in from the environment.
	t.Setenv("OPENAI_API_KEY", "")

	workDir := t.TempDir()
	configFile := filepath.Join(workDir, "config.toml")

	settings := config.Default()
	settings.DBPath = filepath.Join(workDir, "discrawl.db")
	settings.Search.Embeddings.Enabled = true
	settings.Search.Embeddings.Provider = "openai"
	require.NoError(t, config.Write(configFile, settings))

	var report bytes.Buffer
	rt := &runtime{
		ctx:        context.Background(),
		configPath: configFile,
		stdout:     &report,
		stderr:     &bytes.Buffer{},
		logger:     discardLogger(),
	}
	// The doctor command itself must not fail because of the provider issue.
	require.NoError(t, rt.runDoctor(nil))

	printed := report.String()
	require.Contains(t, printed, "embeddings=warning")
	require.Contains(t, printed, "embeddings_warning=embedding provider \"openai\" requires API key env OPENAI_API_KEY")
}
// TestDoctorReportsUnsupportedEmbeddingProviderNonFatally configures an
// unknown provider name and expects `doctor` to succeed while surfacing an
// "unsupported embedding provider" warning.
func TestDoctorReportsUnsupportedEmbeddingProviderNonFatally(t *testing.T) {
	workDir := t.TempDir()
	configFile := filepath.Join(workDir, "config.toml")

	settings := config.Default()
	settings.DBPath = filepath.Join(workDir, "discrawl.db")
	settings.Search.Embeddings.Enabled = true
	settings.Search.Embeddings.Provider = "bogus"
	settings.Search.Embeddings.Model = "custom"
	settings.Search.Embeddings.APIKeyEnv = ""
	require.NoError(t, config.Write(configFile, settings))

	var report bytes.Buffer
	rt := &runtime{
		ctx:        context.Background(),
		configPath: configFile,
		stdout:     &report,
		stderr:     &bytes.Buffer{},
		logger:     discardLogger(),
	}
	require.NoError(t, rt.runDoctor(nil))

	printed := report.String()
	require.Contains(t, printed, "embeddings=warning")
	require.Contains(t, printed, "embeddings_warning=unsupported embedding provider \"bogus\"")
}
func TestRuntimeConfiguresAttachmentTextOnSyncer(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()

View File

@ -9,6 +9,7 @@ import (
"runtime"
"sort"
"strings"
"time"
"github.com/pelletier/go-toml/v2"
)
@ -60,11 +61,14 @@ type ShareConfig struct {
}
// EmbeddingsConfig holds the `search.embeddings` settings.
type EmbeddingsConfig struct {
Enabled bool `toml:"enabled"`
Provider string `toml:"provider"`
Model string `toml:"model"`
APIKeyEnv string `toml:"api_key_env"`
BatchSize int `toml:"batch_size"`
Enabled bool `toml:"enabled"`
Provider string `toml:"provider"`
Model string `toml:"model"`
// BaseURL is the provider endpoint root; Normalize trims surrounding
// whitespace and trailing slashes from it.
BaseURL string `toml:"base_url"`
APIKeyEnv string `toml:"api_key_env"`
BatchSize int `toml:"batch_size"`
// MaxInputChars falls back to 12000 in Normalize when it is not positive.
MaxInputChars int `toml:"max_input_chars"`
// RequestTimeout is a duration string; Normalize defaults it to "2m" and
// rejects values that do not parse or are not positive.
RequestTimeout string `toml:"request_timeout"`
}
type TokenResolution struct {
@ -120,11 +124,13 @@ func Default() Config {
Search: SearchConfig{
DefaultMode: "fts",
Embeddings: EmbeddingsConfig{
Enabled: false,
Provider: "openai",
Model: "text-embedding-3-small",
APIKeyEnv: "OPENAI_API_KEY",
BatchSize: 64,
Enabled: false,
Provider: "openai",
Model: "text-embedding-3-small",
APIKeyEnv: "OPENAI_API_KEY",
BatchSize: 64,
MaxInputChars: 12000,
RequestTimeout: "2m",
},
},
Share: ShareConfig{
@ -239,15 +245,28 @@ func (c *Config) Normalize() error {
if c.Search.DefaultMode == "" {
c.Search.DefaultMode = "fts"
}
c.Search.Embeddings.Provider = strings.ToLower(strings.TrimSpace(c.Search.Embeddings.Provider))
c.Search.Embeddings.Model = strings.TrimSpace(c.Search.Embeddings.Model)
c.Search.Embeddings.BaseURL = strings.TrimRight(strings.TrimSpace(c.Search.Embeddings.BaseURL), "/")
c.Search.Embeddings.APIKeyEnv = strings.TrimSpace(c.Search.Embeddings.APIKeyEnv)
c.Search.Embeddings.RequestTimeout = strings.TrimSpace(c.Search.Embeddings.RequestTimeout)
if c.Search.Embeddings.Provider == "" {
c.Search.Embeddings.Provider = "openai"
}
if c.Search.Embeddings.Model == "" {
c.Search.Embeddings.Model = "text-embedding-3-small"
switch strings.ToLower(strings.TrimSpace(c.Search.Embeddings.Provider)) {
case "ollama", "llamacpp":
c.Search.Embeddings.Model = "nomic-embed-text"
default:
c.Search.Embeddings.Model = "text-embedding-3-small"
}
}
if c.Search.Embeddings.APIKeyEnv == "" {
if c.Search.Embeddings.APIKeyEnv == "" && c.Search.Embeddings.Provider == "openai" {
c.Search.Embeddings.APIKeyEnv = "OPENAI_API_KEY"
}
if (c.Search.Embeddings.Provider == "ollama" || c.Search.Embeddings.Provider == "llamacpp") && c.Search.Embeddings.APIKeyEnv == "OPENAI_API_KEY" {
c.Search.Embeddings.APIKeyEnv = ""
}
if c.Search.Embeddings.BatchSize <= 0 {
c.Search.Embeddings.BatchSize = 64
}
@ -260,6 +279,17 @@ func (c *Config) Normalize() error {
if c.Share.StaleAfter == "" {
c.Share.StaleAfter = "15m"
}
if c.Search.Embeddings.MaxInputChars <= 0 {
c.Search.Embeddings.MaxInputChars = 12000
}
if c.Search.Embeddings.RequestTimeout == "" {
c.Search.Embeddings.RequestTimeout = "2m"
}
if timeout, err := time.ParseDuration(c.Search.Embeddings.RequestTimeout); err != nil {
return fmt.Errorf("parse search.embeddings.request_timeout: %w", err)
} else if timeout <= 0 {
return errors.New("search.embeddings.request_timeout must be positive")
}
c.GuildIDs = uniqueStrings(c.GuildIDs)
return nil
}

View File

@ -29,6 +29,13 @@ func TestNormalizeFillsDefaults(t *testing.T) {
require.False(t, cfg.ShareEnabled())
cfg.Share.Remote = "git@example.com:org/archive.git"
require.True(t, cfg.ShareEnabled())
require.Equal(t, "openai", cfg.Search.Embeddings.Provider)
require.Equal(t, "text-embedding-3-small", cfg.Search.Embeddings.Model)
require.Empty(t, cfg.Search.Embeddings.BaseURL)
require.Equal(t, "OPENAI_API_KEY", cfg.Search.Embeddings.APIKeyEnv)
require.Equal(t, 64, cfg.Search.Embeddings.BatchSize)
require.Equal(t, 12000, cfg.Search.Embeddings.MaxInputChars)
require.Equal(t, "2m", cfg.Search.Embeddings.RequestTimeout)
}
func TestResolveDiscordTokenPrefersOpenClaw(t *testing.T) {
@ -138,6 +145,78 @@ func TestWriteAndLoadRoundTrip(t *testing.T) {
require.True(t, *loaded.Sync.AttachmentText)
}
// TestNormalizeEmbeddingProviderDefaults walks through how Normalize fills in
// embedding settings for each provider family: provider-name canonicalisation,
// per-provider model defaults, API key env handling, and base URL cleanup.
func TestNormalizeEmbeddingProviderDefaults(t *testing.T) {
	t.Parallel()

	// Starting from Default(): the provider is lower-cased, the preset OpenAI
	// model is kept, and the OpenAI key env is cleared for Ollama.
	fromDefault := Default()
	fromDefault.Search.Embeddings.Provider = "OLLAMA"
	require.NoError(t, fromDefault.Normalize())
	require.Equal(t, "ollama", fromDefault.Search.Embeddings.Provider)
	require.Equal(t, "text-embedding-3-small", fromDefault.Search.Embeddings.Model)
	require.Empty(t, fromDefault.Search.Embeddings.APIKeyEnv)
	require.Empty(t, fromDefault.Search.Embeddings.BaseURL)
	require.Equal(t, "2m", fromDefault.Search.Embeddings.RequestTimeout)

	// Zero config with llamacpp: the local model default applies and no key
	// env is invented.
	llama := Config{}
	llama.Search.Embeddings.Provider = "llamacpp"
	require.NoError(t, llama.Normalize())
	require.Equal(t, "nomic-embed-text", llama.Search.Embeddings.Model)
	require.Empty(t, llama.Search.Embeddings.APIKeyEnv)

	// openai_compatible: the base URL loses whitespace and its trailing slash.
	compatible := Config{}
	compatible.Search.Embeddings.Provider = "openai_compatible"
	compatible.Search.Embeddings.BaseURL = " http://127.0.0.1:9999/v1/ "
	require.NoError(t, compatible.Normalize())
	require.Equal(t, "openai_compatible", compatible.Search.Embeddings.Provider)
	require.Equal(t, "http://127.0.0.1:9999/v1", compatible.Search.Embeddings.BaseURL)
	require.Equal(t, "text-embedding-3-small", compatible.Search.Embeddings.Model)
	require.Empty(t, compatible.Search.Embeddings.APIKeyEnv)

	// openai_compatible with an explicit key env: the value is preserved.
	withKey := Config{}
	withKey.Search.Embeddings.Provider = "openai_compatible"
	withKey.Search.Embeddings.APIKeyEnv = "OPENAI_API_KEY"
	require.NoError(t, withKey.Normalize())
	require.Equal(t, "OPENAI_API_KEY", withKey.Search.Embeddings.APIKeyEnv)
}
// TestLoadLegacyEmbeddingConfigDefaults loads a config file that only carries
// the older embedding keys and checks that the newer settings (base URL, input
// cap, request timeout) come back with their defaults.
func TestLoadLegacyEmbeddingConfigDefaults(t *testing.T) {
	t.Parallel()

	// Config as written before base_url / max_input_chars / request_timeout existed.
	const legacyConfig = `
db_path = "/tmp/discrawl.db"
cache_dir = "/tmp/discrawl-cache"
log_dir = "/tmp/discrawl-logs"
[search.embeddings]
enabled = true
provider = "openai"
model = "text-embedding-3-small"
api_key_env = "OPENAI_API_KEY"
batch_size = 64
`
	configFile := filepath.Join(t.TempDir(), "config.toml")
	require.NoError(t, os.WriteFile(configFile, []byte(legacyConfig), 0o600))

	loaded, err := Load(configFile)
	require.NoError(t, err)

	require.True(t, loaded.Search.Embeddings.Enabled)
	require.Equal(t, "openai", loaded.Search.Embeddings.Provider)
	require.Empty(t, loaded.Search.Embeddings.BaseURL)
	require.Equal(t, 12000, loaded.Search.Embeddings.MaxInputChars)
	require.Equal(t, "2m", loaded.Search.Embeddings.RequestTimeout)
}
// TestNormalizeRejectsInvalidEmbeddingTimeout checks that Normalize refuses a
// non-positive request timeout and one that is not a valid duration string.
func TestNormalizeRejectsInvalidEmbeddingTimeout(t *testing.T) {
	t.Parallel()

	cases := []struct {
		timeout string
		wantErr string
	}{
		{timeout: "0s", wantErr: "must be positive"},
		{timeout: "soon", wantErr: "parse search.embeddings.request_timeout"},
	}
	for _, tc := range cases {
		cfg := Default()
		cfg.Search.Embeddings.RequestTimeout = tc.timeout
		require.ErrorContains(t, cfg.Normalize(), tc.wantErr)
	}
}
func TestAttachmentTextExplicitFalseSurvivesNormalize(t *testing.T) {
t.Parallel()

91
internal/embed/ollama.go Normal file
View File

@ -0,0 +1,91 @@
package embed
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)
// ollamaProvider implements Provider by posting input batches to an Ollama
// server's /api/embed endpoint.
type ollamaProvider struct {
// client performs the HTTP requests.
client *http.Client
// baseURL is prefixed to "/api/embed" to form the request endpoint.
baseURL string
// model is sent with every request and is the fallback model name for the
// returned batch when the response does not name one.
model string
// maxInputChars is handed to trimInputs to cap each input before sending.
maxInputChars int
}
// ollamaEmbedRequest is the JSON body posted to the Ollama embed endpoint.
type ollamaEmbedRequest struct {
// Model names the embedding model to use.
Model string `json:"model"`
// Input carries the (already trimmed) texts to embed, one vector expected per entry.
Input []string `json:"input"`
}
// ollamaEmbedResponse is the subset of the Ollama embed response that is decoded.
type ollamaEmbedResponse struct {
// Model is the model name reported by the server; it may be empty.
Model string `json:"model"`
// Embeddings holds one vector per input, in the order returned by the server.
Embeddings [][]float32 `json:"embeddings"`
}
// newOllamaProvider builds an Ollama-backed Provider from resolved settings,
// copying over the HTTP client, base URL, model, and input length cap.
func newOllamaProvider(settings providerSettings) Provider {
	p := new(ollamaProvider)
	p.client = settings.HTTPClient
	p.baseURL = settings.BaseURL
	p.model = settings.Model
	p.maxInputChars = settings.MaxInputChars
	return p
}
// Embed sends inputs to the Ollama /api/embed endpoint and returns one vector
// per input.
//
// An empty inputs slice short-circuits to a batch that only names the
// configured model; no request is made. Inputs are capped with trimInputs
// before sending. An error is returned when the request fails, when the number
// of vectors differs from the number of inputs, or when the vectors are empty
// or of inconsistent length. The batch model is the one reported by the server,
// falling back to the configured model when the response leaves it blank.
func (p *ollamaProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
	expected := len(inputs)
	if expected == 0 {
		return EmbeddingBatch{Model: p.model}, nil
	}

	var decoded ollamaEmbedResponse
	request := ollamaEmbedRequest{Model: p.model, Input: trimInputs(inputs, p.maxInputChars)}
	endpoint := p.baseURL + "/api/embed"
	// Ollama takes no API key, hence the empty key argument.
	if err := postJSON(ctx, p.client, endpoint, "", request, &decoded); err != nil {
		return EmbeddingBatch{}, err
	}

	if got := len(decoded.Embeddings); got != expected {
		return EmbeddingBatch{}, fmt.Errorf("ollama embedding response returned %d vectors for %d inputs", got, expected)
	}
	width, err := inferDimensions(decoded.Embeddings)
	if err != nil {
		return EmbeddingBatch{}, err
	}

	batch := EmbeddingBatch{Model: decoded.Model, Dimensions: width, Vectors: decoded.Embeddings}
	if batch.Model == "" {
		batch.Model = p.model
	}
	return batch, nil
}
// postJSON marshals payload as JSON, POSTs it to endpoint using client, and
// decodes a 2xx JSON response body into target.
//
// When apiKey is non-empty it is sent as an "Authorization: Bearer" header.
// A non-2xx status yields an error carrying the status code and up to 4096
// bytes of the response body. The snippet has surrounding whitespace removed
// so that the error stays on a single line even when the server terminates its
// message with a newline (as http.Error does); callers copy this text into
// warning fields.
func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, payload any, target any) error {
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal embedding request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("build embedding request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}

	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("embedding request failed: %w", err)
	}
	// A close failure after the body has been consumed is not actionable here.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best-effort read: a partial or empty snippet is still useful context,
		// so the read error is deliberately ignored. The limit bounds how much
		// of an arbitrary error page ends up in the message.
		msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return fmt.Errorf("embedding request failed with HTTP %d: %s", resp.StatusCode, bytes.TrimSpace(msg))
	}
	if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
		return fmt.Errorf("decode embedding response: %w", err)
	}
	return nil
}

View File

@ -0,0 +1,82 @@
package embed
import (
"context"
"fmt"
"net/http"
)
// openAICompatibleProvider implements Provider by posting input batches to an
// OpenAI-style "/embeddings" endpoint under baseURL.
type openAICompatibleProvider struct {
// client performs the HTTP requests.
client *http.Client
// baseURL is prefixed to "/embeddings" to form the request endpoint.
baseURL string
// apiKey is passed to postJSON, which sends it as a bearer token when non-empty.
apiKey string
// model is sent with every request and is the fallback model name for the
// returned batch when the response does not name one.
model string
// maxInputChars is handed to trimInputs to cap each input before sending.
maxInputChars int
}
// openAIEmbeddingRequest is the JSON body posted to the "/embeddings" endpoint.
type openAIEmbeddingRequest struct {
// Model names the embedding model to use.
Model string `json:"model"`
// Input carries the (already trimmed) texts to embed.
Input []string `json:"input"`
}
// openAIEmbeddingResponse is the subset of the "/embeddings" response that is decoded.
type openAIEmbeddingResponse struct {
// Model is the model name reported by the server; it may be empty.
Model string `json:"model"`
// Data holds one item per input; items may arrive in any order when they carry an index.
Data []openAIEmbeddingItem `json:"data"`
}
// openAIEmbeddingItem is a single embedding entry in an "/embeddings" response.
type openAIEmbeddingItem struct {
// Index is a pointer so that an absent "index" field (nil) can be told apart
// from an explicit 0; Embed falls back to the item's position when it is nil.
Index *int `json:"index"`
// Embedding is the vector for the input at Index.
Embedding []float32 `json:"embedding"`
}
// newOpenAICompatibleProvider builds a Provider for OpenAI-style endpoints from
// resolved settings, copying over the HTTP client, base URL, API key, model,
// and input length cap.
func newOpenAICompatibleProvider(settings providerSettings) Provider {
	p := new(openAICompatibleProvider)
	p.client = settings.HTTPClient
	p.baseURL = settings.BaseURL
	p.apiKey = settings.APIKey
	p.model = settings.Model
	p.maxInputChars = settings.MaxInputChars
	return p
}
// Embed sends inputs to the "/embeddings" endpoint and returns one vector per
// input, ordered to match inputs.
//
// An empty inputs slice short-circuits to a batch that only names the
// configured model; no request is made. Inputs are capped with trimInputs
// before sending. Response items are placed by their "index" field, or by
// their position in the response when the field is absent. An error is
// returned when the request fails, when the item count differs from the input
// count, when an index is out of range or repeated, or when the vectors are
// empty or of inconsistent length. The batch model is the one reported by the
// server, falling back to the configured model when the response leaves it
// blank.
func (p *openAICompatibleProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
	expected := len(inputs)
	if expected == 0 {
		return EmbeddingBatch{Model: p.model}, nil
	}

	var decoded openAIEmbeddingResponse
	request := openAIEmbeddingRequest{Model: p.model, Input: trimInputs(inputs, p.maxInputChars)}
	if err := postJSON(ctx, p.client, p.baseURL+"/embeddings", p.apiKey, request, &decoded); err != nil {
		return EmbeddingBatch{}, err
	}
	if got := len(decoded.Data); got != expected {
		return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response returned %d vectors for %d inputs", got, expected)
	}

	// Slot each vector at the position of the input it belongs to.
	ordered := make([][]float32, expected)
	filled := make([]bool, expected)
	for pos := range decoded.Data {
		entry := decoded.Data[pos]
		slot := pos
		if entry.Index != nil {
			slot = *entry.Index
		}
		switch {
		case slot < 0 || slot >= expected:
			return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response index %d out of range", slot)
		case filled[slot]:
			return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response duplicated index %d", slot)
		}
		filled[slot] = true
		ordered[slot] = entry.Embedding
	}

	width, err := inferDimensions(ordered)
	if err != nil {
		return EmbeddingBatch{}, err
	}

	batch := EmbeddingBatch{Model: decoded.Model, Dimensions: width, Vectors: ordered}
	if batch.Model == "" {
		batch.Model = p.model
	}
	return batch, nil
}

296
internal/embed/provider.go Normal file
View File

@ -0,0 +1,296 @@
package embed
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/steipete/discrawl/internal/config"
)
// Provider names accepted in configuration, and the defaults applied while
// resolving provider settings.
const (
// Provider identifiers, compared against the trimmed, lower-cased config value.
ProviderOpenAI = "openai"
ProviderOllama = "ollama"
ProviderLlamaCpp = "llamacpp"
ProviderOpenAICompatible = "openai_compatible"
// Base URLs used when the configuration leaves base_url empty.
// ProviderOpenAICompatible has no default and must be given one.
DefaultOpenAIBaseURL = "https://api.openai.com/v1"
DefaultOllamaBaseURL = "http://127.0.0.1:11434"
DefaultLlamaCppBaseURL = "http://127.0.0.1:8080/v1"
// Models used when the configuration leaves the model empty:
// the local model for Ollama and llama.cpp, the OpenAI model otherwise.
DefaultOpenAIModel = "text-embedding-3-small"
DefaultLocalEmbeddingModel = "nomic-embed-text"
// NOTE(review): DefaultBatchSize is not referenced in this file; presumably
// it mirrors the config default of 64 for callers that batch inputs — confirm.
DefaultBatchSize = 64
// DefaultMaxInputChars is the per-input character cap used when none is configured.
DefaultMaxInputChars = 12000
// DefaultRequestTimeout applies when the configuration has no request_timeout.
DefaultRequestTimeout = 2 * time.Minute
// DefaultProbeTimeout bounds the probe request made by CheckProvider.
DefaultProbeTimeout = 2 * time.Second
)
// Provider is the narrow embedding surface used by later queue/search work.
//
// Embed turns a batch of input texts into an EmbeddingBatch. The
// implementations in this package return one vector per input, in input order,
// and return an empty batch without contacting the server when inputs is empty.
type Provider interface {
Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error)
}
// EmbeddingBatch is the result of embedding one batch of inputs.
type EmbeddingBatch struct {
// Model is the model name reported by the server, or the configured model
// when the response did not include one.
Model string
// Dimensions is the common length of every vector in Vectors (0 for an empty batch).
Dimensions int
// Vectors holds one embedding per input.
Vectors [][]float32
}
// CheckResult summarises what CheckProvider found for a configured provider.
type CheckResult struct {
// Provider, Model and BaseURL are the resolved settings, or the cleaned-up
// raw config values when resolution itself failed.
Provider string
Model string
BaseURL string
// Status is "ok" or "warning".
Status string
// Warning carries the error text when Status is "warning"; empty otherwise.
Warning string
// Probed is true only when a probe request was made and succeeded.
Probed bool
}
// Option customises how provider settings are resolved; see WithHTTPClient
// and WithRequestTimeout.
type Option func(*providerOptions)
// providerOptions collects the values set by Option functions.
type providerOptions struct {
// httpClient, when non-nil, is used instead of a freshly built client.
httpClient *http.Client
// timeoutOverride, when positive and shorter than the configured timeout,
// replaces it.
timeoutOverride time.Duration
}
// providerSettings is the fully resolved, validated configuration from which
// a concrete provider is constructed.
type providerSettings struct {
// Name is the normalised provider identifier (one of the Provider* constants).
Name string
// Model is the configured model or the provider's default.
Model string
// BaseURL is the endpoint root without a trailing slash.
BaseURL string
// APIKey is the value read from the API key environment variable; it may be empty.
APIKey string
// MaxInputChars is the per-input character cap.
MaxInputChars int
// Timeout is the effective request timeout after applying any override.
Timeout time.Duration
// HTTPClient is the caller-supplied client, or a new client using Timeout.
HTTPClient *http.Client
}
// WithHTTPClient makes the provider use client for its requests. A nil client
// leaves the default behaviour in place: a new client is built with the
// resolved request timeout.
func WithHTTPClient(client *http.Client) Option {
	apply := func(target *providerOptions) {
		target.httpClient = client
	}
	return apply
}
// WithRequestTimeout supplies an upper bound for the request timeout. It only
// takes effect when positive and shorter than the configured timeout.
func WithRequestTimeout(timeout time.Duration) Option {
	apply := func(target *providerOptions) {
		target.timeoutOverride = timeout
	}
	return apply
}
// NewProvider resolves cfg (with API key validation enabled) and constructs
// the matching Provider. It returns an error when the configuration cannot be
// resolved or names an unsupported provider.
func NewProvider(cfg config.EmbeddingsConfig, opts ...Option) (Provider, error) {
	resolved, err := resolveProviderConfig(cfg, true, opts...)
	if err == nil {
		return newProvider(resolved)
	}
	return nil, err
}
// CheckProvider validates the embeddings configuration and, for providers that
// shouldProbe selects, sends a single probe embedding request.
//
// It never returns an error: problems are reported as Status "warning" with
// the error text in Warning. When the configuration cannot be resolved, the
// result echoes the cleaned-up raw config values. When no probe is attempted,
// the result is "ok" with Probed false. The probe is bounded by
// DefaultProbeTimeout, and Probed is set only after it succeeds.
func CheckProvider(ctx context.Context, cfg config.EmbeddingsConfig) CheckResult {
	settings, err := resolveProviderConfig(cfg, true, WithRequestTimeout(DefaultProbeTimeout))
	if err != nil {
		return CheckResult{
			Provider: normalizedProviderName(cfg.Provider),
			Model:    strings.TrimSpace(cfg.Model),
			BaseURL:  strings.TrimSpace(cfg.BaseURL),
			Status:   "warning",
			Warning:  err.Error(),
		}
	}

	report := CheckResult{
		Provider: settings.Name,
		Model:    settings.Model,
		BaseURL:  settings.BaseURL,
		Status:   "ok",
	}
	if !shouldProbe(settings) {
		return report
	}

	// warn downgrades the report to a warning carrying the cause's message.
	warn := func(cause error) CheckResult {
		report.Status = "warning"
		report.Warning = cause.Error()
		return report
	}

	provider, err := newProvider(settings)
	if err != nil {
		return warn(err)
	}

	probeCtx, cancel := context.WithTimeout(ctx, DefaultProbeTimeout)
	defer cancel()
	if _, err = provider.Embed(probeCtx, []string{"discrawl probe"}); err != nil {
		return warn(err)
	}

	report.Probed = true
	return report
}
// resolveProviderConfig turns user configuration plus options into the
// concrete settings a provider needs, applying defaults and validating along
// the way.
//
// Checks run in this order and the first failure is returned:
//  1. openai_compatible without a base_url;
//  2. request_timeout that does not parse or is not positive;
//  3. unsupported provider name;
//  4. missing API key (see resolveAPIKey; validateAPIKey is passed through);
//  5. base_url that is not an absolute URL with a scheme and host.
//
// Defaults: an empty provider means ProviderOpenAI, an empty model comes from
// defaultModel, an empty base_url comes from the provider's default, and a
// non-positive MaxInputChars becomes DefaultMaxInputChars. The timeout is the
// configured value (or DefaultRequestTimeout), lowered to the option override
// when that is positive and shorter. A caller-supplied HTTP client is used
// as-is; otherwise a client with the resolved timeout is created.
func resolveProviderConfig(cfg config.EmbeddingsConfig, validateAPIKey bool, opts ...Option) (providerSettings, error) {
	options := providerOptions{}
	for _, opt := range opts {
		opt(&options)
	}

	name := normalizedProviderName(cfg.Provider)
	if name == "" {
		name = ProviderOpenAI
	}
	model := strings.TrimSpace(cfg.Model)
	if model == "" {
		model = defaultModel(name)
	}
	baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/")
	if baseURL == "" {
		switch name {
		case ProviderOpenAI:
			baseURL = DefaultOpenAIBaseURL
		case ProviderOllama:
			baseURL = DefaultOllamaBaseURL
		case ProviderLlamaCpp:
			baseURL = DefaultLlamaCppBaseURL
		case ProviderOpenAICompatible:
			return providerSettings{}, fmt.Errorf("embedding provider %q requires base_url", name)
		}
	}

	timeout := DefaultRequestTimeout
	// Parse the trimmed value: the emptiness check ignores surrounding
	// whitespace, so the parse must too or " 5s " would be rejected.
	if rawTimeout := strings.TrimSpace(cfg.RequestTimeout); rawTimeout != "" {
		parsed, err := time.ParseDuration(rawTimeout)
		if err != nil {
			return providerSettings{}, fmt.Errorf("parse embeddings request_timeout: %w", err)
		}
		if parsed <= 0 {
			return providerSettings{}, errors.New("embeddings request_timeout must be positive")
		}
		timeout = parsed
	}
	// The override can only shorten the timeout, never extend it.
	if options.timeoutOverride > 0 && options.timeoutOverride < timeout {
		timeout = options.timeoutOverride
	}

	maxInputChars := cfg.MaxInputChars
	if maxInputChars <= 0 {
		maxInputChars = DefaultMaxInputChars
	}

	// Reject unknown providers before looking at the API key so the more
	// fundamental problem is the one reported.
	switch name {
	case ProviderOpenAI, ProviderOllama, ProviderLlamaCpp, ProviderOpenAICompatible:
	default:
		return providerSettings{}, fmt.Errorf("unsupported embedding provider %q", name)
	}
	apiKey, err := resolveAPIKey(name, cfg.APIKeyEnv, validateAPIKey)
	if err != nil {
		return providerSettings{}, err
	}

	client := options.httpClient
	if client == nil {
		client = &http.Client{Timeout: timeout}
	}

	parsedURL, err := url.ParseRequestURI(baseURL)
	if err != nil {
		return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: %w", baseURL, err)
	}
	// ParseRequestURI also accepts bare paths ("/v1") and "host:port" strings
	// (parsed as scheme:opaque); neither can be requested, so fail here with a
	// clear message instead of at the first HTTP call.
	if parsedURL.Scheme == "" || parsedURL.Host == "" {
		return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: must be an absolute URL with scheme and host", baseURL)
	}

	return providerSettings{
		Name:          name,
		Model:         model,
		BaseURL:       baseURL,
		APIKey:        apiKey,
		MaxInputChars: maxInputChars,
		Timeout:       timeout,
		HTTPClient:    client,
	}, nil
}
// newProvider picks the implementation for settings.Name: Ollama has its own
// client, while OpenAI, llama.cpp, and generic OpenAI-compatible servers share
// the OpenAI-style client. Any other name is an error.
func newProvider(settings providerSettings) (Provider, error) {
	name := settings.Name
	if name == ProviderOllama {
		return newOllamaProvider(settings), nil
	}
	if name == ProviderOpenAI || name == ProviderLlamaCpp || name == ProviderOpenAICompatible {
		return newOpenAICompatibleProvider(settings), nil
	}
	return nil, fmt.Errorf("unsupported embedding provider %q", name)
}
// resolveAPIKey reads the API key for provider from the environment variable
// named by apiKeyEnv.
//
// Only the OpenAI provider requires a key; for it an empty apiKeyEnv falls
// back to OPENAI_API_KEY. For other providers an empty apiKeyEnv means "no
// key" and yields an empty string. When a variable name is known but its value
// is blank, an error is returned if the key is required or validate is set;
// otherwise the result is an empty key.
func resolveAPIKey(provider, apiKeyEnv string, validate bool) (string, error) {
	mustHaveKey := provider == ProviderOpenAI
	variable := strings.TrimSpace(apiKeyEnv)
	switch {
	case variable != "":
		// Use the configured variable as given.
	case mustHaveKey:
		variable = "OPENAI_API_KEY"
	default:
		return "", nil
	}

	if key := strings.TrimSpace(os.Getenv(variable)); key != "" {
		return key, nil
	}
	if !mustHaveKey && !validate {
		return "", nil
	}
	return "", fmt.Errorf("embedding provider %q requires API key env %s", provider, variable)
}
// normalizedProviderName canonicalises a configured provider name by removing
// surrounding whitespace and lower-casing it.
func normalizedProviderName(provider string) string {
	trimmed := strings.TrimSpace(provider)
	return strings.ToLower(trimmed)
}
// defaultModel returns the model used when none is configured: the local
// embedding model for Ollama and llama.cpp, the OpenAI model for everything
// else.
func defaultModel(provider string) string {
	if provider == ProviderOllama || provider == ProviderLlamaCpp {
		return DefaultLocalEmbeddingModel
	}
	return DefaultOpenAIModel
}
// shouldProbe reports whether CheckProvider should send a probe request.
// Ollama and llama.cpp are always probed; an OpenAI-compatible endpoint is
// probed only when its base URL points at the loopback interface; every other
// provider is never probed.
func shouldProbe(settings providerSettings) bool {
	name := settings.Name
	if name == ProviderOllama || name == ProviderLlamaCpp {
		return true
	}
	if name == ProviderOpenAICompatible {
		return isLoopbackBaseURL(settings.BaseURL)
	}
	return false
}
// isLoopbackBaseURL reports whether rawURL's host refers to the local machine:
// either the name "localhost" or a loopback IP address (for example 127.0.0.1
// or ::1). It returns false when rawURL cannot be parsed.
//
// Hostnames are case-insensitive and may be written in fully-qualified form
// with a trailing dot, so "LOCALHOST" and "localhost." are both recognised.
func isLoopbackBaseURL(rawURL string) bool {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return false
	}
	// Hostname strips the port and, for IPv6 literals, the square brackets.
	host := strings.TrimSuffix(parsed.Hostname(), ".")
	if strings.EqualFold(host, "localhost") {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}
// trimInputs returns a copy of inputs in which every string is cut to at most
// maxChars characters. Length is counted in runes, so multi-byte characters
// are never split. A non-positive maxChars means DefaultMaxInputChars. The
// inputs slice itself is not modified.
func trimInputs(inputs []string, maxChars int) []string {
	limit := maxChars
	if limit <= 0 {
		limit = DefaultMaxInputChars
	}

	trimmed := make([]string, 0, len(inputs))
	for _, text := range inputs {
		chars := []rune(text)
		if len(chars) > limit {
			chars = chars[:limit]
		}
		trimmed = append(trimmed, string(chars))
	}
	return trimmed
}
// inferDimensions returns the length shared by all vectors. It returns 0 with
// no error when vectors is empty, and an error when any vector is empty or
// when a vector's length differs from that of the first.
func inferDimensions(vectors [][]float32) (int, error) {
	width := 0
	for i := range vectors {
		n := len(vectors[i])
		switch {
		case n == 0:
			return 0, errors.New("embedding response contained an empty vector")
		case width == 0:
			// First vector establishes the expected length.
			width = n
		case n != width:
			return 0, fmt.Errorf("embedding response dimensions mismatch: got %d want %d", n, width)
		}
	}
	return width, nil
}

View File

@ -0,0 +1,193 @@
package embed
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/steipete/discrawl/internal/config"
)
// TestOllamaProviderEmbeds checks the Ollama provider end to end against a
// stub server: request path, method, model, input trimming to MaxInputChars,
// and the decoded batch.
func TestOllamaProviderEmbeds(t *testing.T) {
	t.Parallel()

	handler := func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, "/api/embed", r.URL.Path)
		require.Equal(t, http.MethodPost, r.Method)
		var received ollamaEmbedRequest
		require.NoError(t, json.NewDecoder(r.Body).Decode(&received))
		require.Equal(t, "nomic-embed-text", received.Model)
		// "abcdef" must arrive cut to the 4-character limit.
		require.Equal(t, []string{"abcd", "xy"}, received.Input)
		_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`))
	}
	stub := httptest.NewServer(http.HandlerFunc(handler))
	defer stub.Close()

	settings := config.EmbeddingsConfig{
		Provider:       ProviderOllama,
		Model:          "nomic-embed-text",
		BaseURL:        stub.URL,
		MaxInputChars:  4,
		RequestTimeout: "5s",
	}
	provider, err := NewProvider(settings)
	require.NoError(t, err)

	got, err := provider.Embed(context.Background(), []string{"abcdef", "xy"})
	require.NoError(t, err)
	require.Equal(t, "nomic-embed-text", got.Model)
	require.Equal(t, 3, got.Dimensions)
	require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, got.Vectors)
}
// TestOpenAICompatibleProviderEmbedsAndUsesAuth checks the OpenAI-compatible
// provider against a stub server: the bearer token comes from the configured
// env var, and vectors returned out of order are placed by their index.
func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, "/embeddings", r.URL.Path)
		require.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
		var received openAIEmbeddingRequest
		require.NoError(t, json.NewDecoder(r.Body).Decode(&received))
		require.Equal(t, "local-model", received.Model)
		require.Equal(t, []string{"one", "two"}, received.Input)
		// Items are deliberately returned in reverse order.
		_, _ = w.Write([]byte(`{
"model":"local-model",
"data":[
{"index":1,"embedding":[3,4]},
{"index":0,"embedding":[1,2]}
]
}`))
	}
	stub := httptest.NewServer(http.HandlerFunc(handler))
	defer stub.Close()

	t.Setenv("DISCRAWL_EMBED_KEY", "secret")
	settings := config.EmbeddingsConfig{
		Provider:       ProviderOpenAICompatible,
		Model:          "local-model",
		BaseURL:        stub.URL,
		APIKeyEnv:      "DISCRAWL_EMBED_KEY",
		RequestTimeout: "5s",
	}
	provider, err := NewProvider(settings)
	require.NoError(t, err)

	got, err := provider.Embed(context.Background(), []string{"one", "two"})
	require.NoError(t, err)
	require.Equal(t, "local-model", got.Model)
	require.Equal(t, 2, got.Dimensions)
	require.Equal(t, [][]float32{{1, 2}, {3, 4}}, got.Vectors)
}
func TestProviderFactoryDefaultsAndValidation(t *testing.T) {
t.Setenv("OPENAI_API_KEY", "openai-secret")
openAI, err := resolveProviderConfig(config.EmbeddingsConfig{
Provider: ProviderOpenAI,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL)
require.Equal(t, DefaultOpenAIModel, openAI.Model)
require.Equal(t, "openai-secret", openAI.APIKey)
ollama, err := resolveProviderConfig(config.EmbeddingsConfig{
Provider: ProviderOllama,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL)
require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model)
llamaCpp, err := resolveProviderConfig(config.EmbeddingsConfig{
Provider: ProviderLlamaCpp,
RequestTimeout: "5s",
}, true)
require.NoError(t, err)
require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL)
_, err = resolveProviderConfig(config.EmbeddingsConfig{
Provider: ProviderOpenAICompatible,
RequestTimeout: "5s",
}, true)
require.ErrorContains(t, err, "requires base_url")
}
func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) {
t.Setenv("OPENAI_API_KEY", "")
_, err := NewProvider(config.EmbeddingsConfig{
Provider: ProviderOpenAI,
RequestTimeout: "5s",
})
require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY")
}
func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) {
t.Setenv("MISSING_EMBED_KEY", "")
_, err := NewProvider(config.EmbeddingsConfig{
Provider: "bogus",
APIKeyEnv: "MISSING_EMBED_KEY",
RequestTimeout: "5s",
})
require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"")
}
// TestCheckProviderProbesLocalProvider checks that CheckProvider probes an
// Ollama endpoint and reports ok/probed with no warning when the stub answers.
func TestCheckProviderProbesLocalProvider(t *testing.T) {
	t.Parallel()

	handler := func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, "/api/embed", r.URL.Path)
		_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`))
	}
	stub := httptest.NewServer(http.HandlerFunc(handler))
	defer stub.Close()

	settings := config.EmbeddingsConfig{
		Provider:       ProviderOllama,
		Model:          "nomic-embed-text",
		BaseURL:        stub.URL,
		RequestTimeout: "5s",
	}
	outcome := CheckProvider(context.Background(), settings)

	require.Equal(t, "ok", outcome.Status)
	require.True(t, outcome.Probed)
	require.Empty(t, outcome.Warning)
	require.Equal(t, stub.URL, outcome.BaseURL)
}
func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "not ready", http.StatusServiceUnavailable)
}))
defer server.Close()
result := CheckProvider(context.Background(), config.EmbeddingsConfig{
Provider: ProviderOllama,
Model: "nomic-embed-text",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.Equal(t, "warning", result.Status)
require.Contains(t, result.Warning, "HTTP 503")
require.False(t, result.Probed)
}
func TestProviderRejectsInvalidResponses(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`))
}))
defer server.Close()
provider, err := NewProvider(config.EmbeddingsConfig{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)
_, err = provider.Embed(context.Background(), []string{"one", "two"})
require.ErrorContains(t, err, "dimensions mismatch")
}