refactor: consume crawlkit embedding primitives
This commit is contained in:
parent
40317aa538
commit
40c787c54a
@ -29,6 +29,7 @@
|
|||||||
### Maintenance
|
### Maintenance
|
||||||
|
|
||||||
- Migrated runtime paths, SQLite opening, archive mirror/export/import helpers, output/status wiring, and TUI plumbing onto the shared `crawlkit` infrastructure.
|
- Migrated runtime paths, SQLite opening, archive mirror/export/import helpers, output/status wiring, and TUI plumbing onto the shared `crawlkit` infrastructure.
|
||||||
|
- Moved reusable embedding providers and vector helpers onto `crawlkit` while keeping Discrawl-owned storage, FTS, queueing, and privacy filters local.
|
||||||
- Updated crawlkit through `v0.4.1`, switched imports to `github.com/openclaw/crawlkit`, and added CI smoke coverage for the crawlkit control surface and merge behavior.
|
- Updated crawlkit through `v0.4.1`, switched imports to `github.com/openclaw/crawlkit`, and added CI smoke coverage for the crawlkit control surface and merge behavior.
|
||||||
- Added CodeQL, verified secret scanning, protected automation owners, stale issue automation, `.editorconfig`, and `.gitattributes`.
|
- Added CodeQL, verified secret scanning, protected automation owners, stale issue automation, `.editorconfig`, and `.gitattributes`.
|
||||||
- Added release workflow automation that dispatches the Homebrew tap formula update after GoReleaser publishes a tag.
|
- Added release workflow automation that dispatches the Homebrew tap formula update after GoReleaser publishes a tag.
|
||||||
|
|||||||
2
go.mod
2
go.mod
@ -43,7 +43,7 @@ require (
|
|||||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
github.com/muesli/termenv v0.16.0 // indirect
|
github.com/muesli/termenv v0.16.0 // indirect
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
github.com/openclaw/crawlkit v0.4.2
|
github.com/openclaw/crawlkit v0.5.0
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@ -63,8 +63,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
|||||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/openclaw/crawlkit v0.4.2 h1:Lzzkd2/xSkQk+7KyboMEw+ZS2wmlYvDFLwAB2Z/FwBs=
|
github.com/openclaw/crawlkit v0.5.0 h1:sVqIbQ5v6LiOf+NXcVj93UhfoaJqMbBlrd1lU6uhO9M=
|
||||||
github.com/openclaw/crawlkit v0.4.2/go.mod h1:/AI8o/DeRqXPZJPHq/9mGUjNzLPskm/wTjikRPxEdHY=
|
github.com/openclaw/crawlkit v0.5.0/go.mod h1:/AI8o/DeRqXPZJPHq/9mGUjNzLPskm/wTjikRPxEdHY=
|
||||||
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
||||||
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
|
|||||||
@ -13,10 +13,10 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/openclaw/crawlkit/embed"
|
||||||
"github.com/openclaw/discrawl/internal/config"
|
"github.com/openclaw/discrawl/internal/config"
|
||||||
"github.com/openclaw/discrawl/internal/discord"
|
"github.com/openclaw/discrawl/internal/discord"
|
||||||
"github.com/openclaw/discrawl/internal/discorddesktop"
|
"github.com/openclaw/discrawl/internal/discorddesktop"
|
||||||
"github.com/openclaw/discrawl/internal/embed"
|
|
||||||
"github.com/openclaw/discrawl/internal/share"
|
"github.com/openclaw/discrawl/internal/share"
|
||||||
"github.com/openclaw/discrawl/internal/store"
|
"github.com/openclaw/discrawl/internal/store"
|
||||||
"github.com/openclaw/discrawl/internal/syncer"
|
"github.com/openclaw/discrawl/internal/syncer"
|
||||||
@ -374,7 +374,7 @@ func (r *runtime) runEmbed(args []string) error {
|
|||||||
providerFactory := r.newEmbed
|
providerFactory := r.newEmbed
|
||||||
if providerFactory == nil {
|
if providerFactory == nil {
|
||||||
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
|
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
|
||||||
return embed.NewProvider(cfg)
|
return embed.NewProvider(crawlkitEmbeddingConfig(cfg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
||||||
@ -435,7 +435,7 @@ func (r *runtime) runDoctor(args []string) error {
|
|||||||
report["share_stale_after"] = cfg.Share.StaleAfter
|
report["share_stale_after"] = cfg.Share.StaleAfter
|
||||||
}
|
}
|
||||||
if cfg.Search.Embeddings.Enabled {
|
if cfg.Search.Embeddings.Enabled {
|
||||||
check := embed.CheckProvider(r.ctx, cfg.Search.Embeddings)
|
check := embed.CheckProvider(r.ctx, crawlkitEmbeddingConfig(cfg.Search.Embeddings))
|
||||||
report["embeddings"] = check.Status
|
report["embeddings"] = check.Status
|
||||||
report["embeddings_provider"] = check.Provider
|
report["embeddings_provider"] = check.Provider
|
||||||
report["embeddings_model"] = check.Model
|
report["embeddings_model"] = check.Model
|
||||||
|
|||||||
@ -11,9 +11,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
|
"github.com/openclaw/crawlkit/embed"
|
||||||
"github.com/openclaw/discrawl/internal/config"
|
"github.com/openclaw/discrawl/internal/config"
|
||||||
"github.com/openclaw/discrawl/internal/discord"
|
"github.com/openclaw/discrawl/internal/discord"
|
||||||
"github.com/openclaw/discrawl/internal/embed"
|
|
||||||
"github.com/openclaw/discrawl/internal/share"
|
"github.com/openclaw/discrawl/internal/share"
|
||||||
"github.com/openclaw/discrawl/internal/store"
|
"github.com/openclaw/discrawl/internal/store"
|
||||||
"github.com/openclaw/discrawl/internal/syncer"
|
"github.com/openclaw/discrawl/internal/syncer"
|
||||||
@ -118,6 +118,17 @@ type runtime struct {
|
|||||||
now func() time.Time
|
now func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func crawlkitEmbeddingConfig(cfg config.EmbeddingsConfig) embed.Config {
|
||||||
|
return embed.Config{
|
||||||
|
Provider: cfg.Provider,
|
||||||
|
Model: cfg.Model,
|
||||||
|
BaseURL: cfg.BaseURL,
|
||||||
|
APIKeyEnv: cfg.APIKeyEnv,
|
||||||
|
RequestTimeout: cfg.RequestTimeout,
|
||||||
|
MaxInputChars: cfg.MaxInputChars,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type discordClient interface {
|
type discordClient interface {
|
||||||
syncer.Client
|
syncer.Client
|
||||||
Close() error
|
Close() error
|
||||||
|
|||||||
@ -9,8 +9,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/openclaw/crawlkit/embed"
|
||||||
"github.com/openclaw/discrawl/internal/config"
|
"github.com/openclaw/discrawl/internal/config"
|
||||||
"github.com/openclaw/discrawl/internal/embed"
|
|
||||||
"github.com/openclaw/discrawl/internal/store"
|
"github.com/openclaw/discrawl/internal/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -112,7 +112,7 @@ func (r *runtime) semanticSearchOptions(opts store.SearchOptions) (store.Semanti
|
|||||||
providerFactory := r.newEmbed
|
providerFactory := r.newEmbed
|
||||||
if providerFactory == nil {
|
if providerFactory == nil {
|
||||||
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
|
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
|
||||||
return embed.NewProvider(cfg)
|
return embed.NewProvider(crawlkitEmbeddingConfig(cfg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
||||||
|
|||||||
@ -1,91 +0,0 @@
|
|||||||
package embed
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
	// ollamaProvider carries the HTTP client, endpoint, model name, and input
	// length cap used when requesting embeddings from an Ollama server.
	ollamaProvider struct {
		client        *http.Client
		baseURL       string
		model         string
		maxInputChars int
	}

	// ollamaEmbedRequest is the JSON body posted to the embed endpoint.
	ollamaEmbedRequest struct {
		Model string   `json:"model"`
		Input []string `json:"input"`
	}

	// ollamaEmbedResponse is the JSON body decoded from the embed endpoint.
	ollamaEmbedResponse struct {
		Model      string      `json:"model"`
		Embeddings [][]float32 `json:"embeddings"`
	}
)
|
|
||||||
|
|
||||||
func newOllamaProvider(settings providerSettings) Provider {
|
|
||||||
return &ollamaProvider{
|
|
||||||
client: settings.HTTPClient,
|
|
||||||
baseURL: settings.BaseURL,
|
|
||||||
model: settings.Model,
|
|
||||||
maxInputChars: settings.MaxInputChars,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ollamaProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
|
|
||||||
if len(inputs) == 0 {
|
|
||||||
return EmbeddingBatch{Model: p.model}, nil
|
|
||||||
}
|
|
||||||
payload := ollamaEmbedRequest{
|
|
||||||
Model: p.model,
|
|
||||||
Input: trimInputs(inputs, p.maxInputChars),
|
|
||||||
}
|
|
||||||
var response ollamaEmbedResponse
|
|
||||||
if err := postJSON(ctx, p.client, p.baseURL+"/api/embed", "", payload, &response); err != nil {
|
|
||||||
return EmbeddingBatch{}, err
|
|
||||||
}
|
|
||||||
if len(response.Embeddings) != len(inputs) {
|
|
||||||
return EmbeddingBatch{}, fmt.Errorf("ollama embedding response returned %d vectors for %d inputs", len(response.Embeddings), len(inputs))
|
|
||||||
}
|
|
||||||
dimensions, err := inferDimensions(response.Embeddings)
|
|
||||||
if err != nil {
|
|
||||||
return EmbeddingBatch{}, err
|
|
||||||
}
|
|
||||||
model := response.Model
|
|
||||||
if model == "" {
|
|
||||||
model = p.model
|
|
||||||
}
|
|
||||||
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: response.Embeddings}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, payload any, target any) error {
|
|
||||||
body, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal embedding request: %w", err)
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("build embedding request: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
if apiKey != "" {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("embedding request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
|
||||||
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
|
|
||||||
return fmt.Errorf("decode embedding response: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@ -1,82 +0,0 @@
|
|||||||
package embed
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
	// openAICompatibleProvider carries the HTTP client, endpoint, API key,
	// model name, and input length cap used against an "/embeddings" endpoint.
	openAICompatibleProvider struct {
		client        *http.Client
		baseURL       string
		apiKey        string
		model         string
		maxInputChars int
	}

	// openAIEmbeddingRequest is the JSON body posted to the endpoint.
	openAIEmbeddingRequest struct {
		Model string   `json:"model"`
		Input []string `json:"input"`
	}

	// openAIEmbeddingResponse is the JSON body decoded from the endpoint.
	openAIEmbeddingResponse struct {
		Model string                `json:"model"`
		Data  []openAIEmbeddingItem `json:"data"`
	}

	// openAIEmbeddingItem is one returned vector. Index is a pointer so an
	// absent "index" field can be told apart from an explicit zero.
	openAIEmbeddingItem struct {
		Index     *int      `json:"index"`
		Embedding []float32 `json:"embedding"`
	}
)
|
|
||||||
|
|
||||||
func newOpenAICompatibleProvider(settings providerSettings) Provider {
|
|
||||||
return &openAICompatibleProvider{
|
|
||||||
client: settings.HTTPClient,
|
|
||||||
baseURL: settings.BaseURL,
|
|
||||||
apiKey: settings.APIKey,
|
|
||||||
model: settings.Model,
|
|
||||||
maxInputChars: settings.MaxInputChars,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *openAICompatibleProvider) Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error) {
|
|
||||||
if len(inputs) == 0 {
|
|
||||||
return EmbeddingBatch{Model: p.model}, nil
|
|
||||||
}
|
|
||||||
payload := openAIEmbeddingRequest{
|
|
||||||
Model: p.model,
|
|
||||||
Input: trimInputs(inputs, p.maxInputChars),
|
|
||||||
}
|
|
||||||
var response openAIEmbeddingResponse
|
|
||||||
if err := postJSON(ctx, p.client, p.baseURL+"/embeddings", p.apiKey, payload, &response); err != nil {
|
|
||||||
return EmbeddingBatch{}, err
|
|
||||||
}
|
|
||||||
if len(response.Data) != len(inputs) {
|
|
||||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response returned %d vectors for %d inputs", len(response.Data), len(inputs))
|
|
||||||
}
|
|
||||||
vectors := make([][]float32, len(inputs))
|
|
||||||
seen := make([]bool, len(inputs))
|
|
||||||
for position, item := range response.Data {
|
|
||||||
index := position
|
|
||||||
if item.Index != nil {
|
|
||||||
index = *item.Index
|
|
||||||
}
|
|
||||||
if index < 0 || index >= len(inputs) {
|
|
||||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response index %d out of range", index)
|
|
||||||
}
|
|
||||||
if seen[index] {
|
|
||||||
return EmbeddingBatch{}, fmt.Errorf("openai-compatible embedding response duplicated index %d", index)
|
|
||||||
}
|
|
||||||
seen[index] = true
|
|
||||||
vectors[index] = item.Embedding
|
|
||||||
}
|
|
||||||
dimensions, err := inferDimensions(vectors)
|
|
||||||
if err != nil {
|
|
||||||
return EmbeddingBatch{}, err
|
|
||||||
}
|
|
||||||
model := response.Model
|
|
||||||
if model == "" {
|
|
||||||
model = p.model
|
|
||||||
}
|
|
||||||
return EmbeddingBatch{Model: model, Dimensions: dimensions, Vectors: vectors}, nil
|
|
||||||
}
|
|
||||||
@ -1,310 +0,0 @@
|
|||||||
package embed
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/openclaw/discrawl/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Supported embedding provider names.
const (
	ProviderOpenAI           = "openai"
	ProviderOllama           = "ollama"
	ProviderLlamaCpp         = "llamacpp"
	ProviderOpenAICompatible = "openai_compatible"
)

// Base URLs used when the configuration leaves base_url empty.
const (
	DefaultOpenAIBaseURL   = "https://api.openai.com/v1"
	DefaultOllamaBaseURL   = "http://127.0.0.1:11434"
	DefaultLlamaCppBaseURL = "http://127.0.0.1:8080/v1"
)

// Model names used when the configuration leaves model empty.
const (
	DefaultOpenAIModel         = "text-embedding-3-small"
	DefaultLocalEmbeddingModel = "nomic-embed-text"
)

// Default sizes and timeouts.
const (
	DefaultBatchSize      = 64
	DefaultMaxInputChars  = 12000
	DefaultRequestTimeout = 2 * time.Minute
	DefaultProbeTimeout   = 2 * time.Second
)
|
|
||||||
|
|
||||||
type (
	// Provider is the narrow embedding surface used by later queue/search work.
	Provider interface {
		// Embed returns one batch of vectors for the given inputs.
		Embed(ctx context.Context, inputs []string) (EmbeddingBatch, error)
	}

	// EmbeddingBatch is the result of a single Embed call: the model that
	// produced the vectors, their shared dimensionality, and the vectors.
	EmbeddingBatch struct {
		Model      string
		Dimensions int
		Vectors    [][]float32
	}
)
|
|
||||||
|
|
||||||
// HTTPError describes an embedding endpoint response with a non-success
// status: the status code and the captured response body.
type HTTPError struct {
	StatusCode int
	Body       string
}

// Error implements the error interface.
func (e *HTTPError) Error() string {
	const format = "embedding request failed with HTTP %d: %s"
	return fmt.Sprintf(format, e.StatusCode, e.Body)
}

// IsRateLimitError reports whether err is, or wraps, an *HTTPError whose
// status code is 429 Too Many Requests.
func IsRateLimitError(err error) bool {
	var target *HTTPError
	if !errors.As(err, &target) {
		return false
	}
	return target.StatusCode == http.StatusTooManyRequests
}
|
|
||||||
|
|
||||||
type (
	// CheckResult is the outcome of CheckProvider: the resolved provider,
	// model, and base URL, a status string, an optional warning message, and
	// whether a probe request succeeded.
	CheckResult struct {
		Provider string
		Model    string
		BaseURL  string
		Status   string
		Warning  string
		Probed   bool
	}

	// Option adjusts how a provider is constructed.
	Option func(*providerOptions)

	// providerOptions accumulates the values set through Option functions.
	providerOptions struct {
		httpClient      *http.Client
		timeoutOverride time.Duration
	}

	// providerSettings is the fully resolved configuration handed to the
	// concrete provider constructors.
	providerSettings struct {
		Name          string
		Model         string
		BaseURL       string
		APIKey        string
		MaxInputChars int
		Timeout       time.Duration
		HTTPClient    *http.Client
	}
)
|
|
||||||
|
|
||||||
func WithHTTPClient(client *http.Client) Option {
|
|
||||||
return func(opts *providerOptions) {
|
|
||||||
opts.httpClient = client
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithRequestTimeout(timeout time.Duration) Option {
|
|
||||||
return func(opts *providerOptions) {
|
|
||||||
opts.timeoutOverride = timeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProvider(cfg config.EmbeddingsConfig, opts ...Option) (Provider, error) {
|
|
||||||
settings, err := resolveProviderConfig(cfg, true, opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return newProvider(settings)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CheckProvider(ctx context.Context, cfg config.EmbeddingsConfig) CheckResult {
|
|
||||||
settings, err := resolveProviderConfig(cfg, true, WithRequestTimeout(DefaultProbeTimeout))
|
|
||||||
if err != nil {
|
|
||||||
return CheckResult{
|
|
||||||
Provider: normalizedProviderName(cfg.Provider),
|
|
||||||
Model: strings.TrimSpace(cfg.Model),
|
|
||||||
BaseURL: strings.TrimSpace(cfg.BaseURL),
|
|
||||||
Status: "warning",
|
|
||||||
Warning: err.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result := CheckResult{
|
|
||||||
Provider: settings.Name,
|
|
||||||
Model: settings.Model,
|
|
||||||
BaseURL: settings.BaseURL,
|
|
||||||
Status: "ok",
|
|
||||||
}
|
|
||||||
if !shouldProbe(settings) {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
provider, err := newProvider(settings)
|
|
||||||
if err != nil {
|
|
||||||
result.Status = "warning"
|
|
||||||
result.Warning = err.Error()
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
probeCtx, cancel := context.WithTimeout(ctx, DefaultProbeTimeout)
|
|
||||||
defer cancel()
|
|
||||||
if _, err := provider.Embed(probeCtx, []string{"discrawl probe"}); err != nil {
|
|
||||||
result.Status = "warning"
|
|
||||||
result.Warning = err.Error()
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
result.Probed = true
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveProviderConfig(cfg config.EmbeddingsConfig, validateAPIKey bool, opts ...Option) (providerSettings, error) {
|
|
||||||
options := providerOptions{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(&options)
|
|
||||||
}
|
|
||||||
name := normalizedProviderName(cfg.Provider)
|
|
||||||
if name == "" {
|
|
||||||
name = ProviderOpenAI
|
|
||||||
}
|
|
||||||
model := strings.TrimSpace(cfg.Model)
|
|
||||||
if model == "" {
|
|
||||||
model = defaultModel(name)
|
|
||||||
}
|
|
||||||
baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/")
|
|
||||||
if baseURL == "" {
|
|
||||||
switch name {
|
|
||||||
case ProviderOpenAI:
|
|
||||||
baseURL = DefaultOpenAIBaseURL
|
|
||||||
case ProviderOllama:
|
|
||||||
baseURL = DefaultOllamaBaseURL
|
|
||||||
case ProviderLlamaCpp:
|
|
||||||
baseURL = DefaultLlamaCppBaseURL
|
|
||||||
case ProviderOpenAICompatible:
|
|
||||||
return providerSettings{}, fmt.Errorf("embedding provider %q requires base_url", name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
timeout := DefaultRequestTimeout
|
|
||||||
if strings.TrimSpace(cfg.RequestTimeout) != "" {
|
|
||||||
parsed, err := time.ParseDuration(cfg.RequestTimeout)
|
|
||||||
if err != nil {
|
|
||||||
return providerSettings{}, fmt.Errorf("parse embeddings request_timeout: %w", err)
|
|
||||||
}
|
|
||||||
if parsed <= 0 {
|
|
||||||
return providerSettings{}, errors.New("embeddings request_timeout must be positive")
|
|
||||||
}
|
|
||||||
timeout = parsed
|
|
||||||
}
|
|
||||||
if options.timeoutOverride > 0 && options.timeoutOverride < timeout {
|
|
||||||
timeout = options.timeoutOverride
|
|
||||||
}
|
|
||||||
maxInputChars := cfg.MaxInputChars
|
|
||||||
if maxInputChars <= 0 {
|
|
||||||
maxInputChars = DefaultMaxInputChars
|
|
||||||
}
|
|
||||||
switch name {
|
|
||||||
case ProviderOpenAI, ProviderOllama, ProviderLlamaCpp, ProviderOpenAICompatible:
|
|
||||||
default:
|
|
||||||
return providerSettings{}, fmt.Errorf("unsupported embedding provider %q", name)
|
|
||||||
}
|
|
||||||
apiKey, err := resolveAPIKey(name, cfg.APIKeyEnv, validateAPIKey)
|
|
||||||
if err != nil {
|
|
||||||
return providerSettings{}, err
|
|
||||||
}
|
|
||||||
client := options.httpClient
|
|
||||||
if client == nil {
|
|
||||||
client = &http.Client{Timeout: timeout}
|
|
||||||
}
|
|
||||||
if _, err := url.ParseRequestURI(baseURL); err != nil {
|
|
||||||
return providerSettings{}, fmt.Errorf("invalid embeddings base_url %q: %w", baseURL, err)
|
|
||||||
}
|
|
||||||
return providerSettings{
|
|
||||||
Name: name,
|
|
||||||
Model: model,
|
|
||||||
BaseURL: baseURL,
|
|
||||||
APIKey: apiKey,
|
|
||||||
MaxInputChars: maxInputChars,
|
|
||||||
Timeout: timeout,
|
|
||||||
HTTPClient: client,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newProvider(settings providerSettings) (Provider, error) {
|
|
||||||
switch settings.Name {
|
|
||||||
case ProviderOllama:
|
|
||||||
return newOllamaProvider(settings), nil
|
|
||||||
case ProviderOpenAI, ProviderLlamaCpp, ProviderOpenAICompatible:
|
|
||||||
return newOpenAICompatibleProvider(settings), nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported embedding provider %q", settings.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveAPIKey(provider, apiKeyEnv string, validate bool) (string, error) {
|
|
||||||
envName := strings.TrimSpace(apiKeyEnv)
|
|
||||||
required := provider == ProviderOpenAI
|
|
||||||
if envName == "" {
|
|
||||||
if required {
|
|
||||||
envName = "OPENAI_API_KEY"
|
|
||||||
} else {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
value := strings.TrimSpace(os.Getenv(envName))
|
|
||||||
if value == "" {
|
|
||||||
if required || validate {
|
|
||||||
return "", fmt.Errorf("embedding provider %q requires API key env %s", provider, envName)
|
|
||||||
}
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// normalizedProviderName trims surrounding whitespace from provider and
// lower-cases the result.
func normalizedProviderName(provider string) string {
	trimmed := strings.TrimSpace(provider)
	return strings.ToLower(trimmed)
}
|
|
||||||
|
|
||||||
func defaultModel(provider string) string {
|
|
||||||
switch provider {
|
|
||||||
case ProviderOllama, ProviderLlamaCpp:
|
|
||||||
return DefaultLocalEmbeddingModel
|
|
||||||
default:
|
|
||||||
return DefaultOpenAIModel
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldProbe(settings providerSettings) bool {
|
|
||||||
switch settings.Name {
|
|
||||||
case ProviderOllama, ProviderLlamaCpp:
|
|
||||||
return true
|
|
||||||
case ProviderOpenAICompatible:
|
|
||||||
return isLoopbackBaseURL(settings.BaseURL)
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isLoopbackBaseURL reports whether rawURL points at the local machine: its
// host is "localhost" (compared case-insensitively, with an optional trailing
// dot) or an IP address in a loopback range. It returns false when rawURL
// cannot be parsed.
func isLoopbackBaseURL(rawURL string) bool {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return false
	}
	host := parsed.Hostname()
	// Host names are case-insensitive and url.Parse does not normalize case,
	// so "LOCALHOST" and the fully-qualified "localhost." must match too.
	if strings.EqualFold(strings.TrimSuffix(host, "."), "localhost") {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}
|
|
||||||
|
|
||||||
func trimInputs(inputs []string, maxChars int) []string {
|
|
||||||
if maxChars <= 0 {
|
|
||||||
maxChars = DefaultMaxInputChars
|
|
||||||
}
|
|
||||||
out := make([]string, len(inputs))
|
|
||||||
for i, input := range inputs {
|
|
||||||
runes := []rune(input)
|
|
||||||
if len(runes) > maxChars {
|
|
||||||
runes = runes[:maxChars]
|
|
||||||
}
|
|
||||||
out[i] = string(runes)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// inferDimensions returns the common length of vectors. It returns an error
// when any vector is empty or when the vectors do not all share one length.
// An empty or nil slice yields 0 with no error.
func inferDimensions(vectors [][]float32) (int, error) {
	width := 0
	for i := range vectors {
		size := len(vectors[i])
		switch {
		case size == 0:
			return 0, errors.New("embedding response contained an empty vector")
		case width == 0:
			// The first vector fixes the expected length.
			width = size
		case size != width:
			return 0, fmt.Errorf("embedding response dimensions mismatch: got %d want %d", size, width)
		}
	}
	return width, nil
}
|
|
||||||
@ -1,387 +0,0 @@
|
|||||||
package embed
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/openclaw/discrawl/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOllamaProviderEmbeds(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
|
||||||
assert.Equal(t, http.MethodPost, r.Method)
|
|
||||||
var req ollamaEmbedRequest
|
|
||||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
|
||||||
assert.Equal(t, "nomic-embed-text", req.Model)
|
|
||||||
assert.Equal(t, []string{"abcd", "xy"}, req.Input)
|
|
||||||
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2,3],[4,5,6]]}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOllama,
|
|
||||||
Model: "nomic-embed-text",
|
|
||||||
BaseURL: server.URL,
|
|
||||||
MaxInputChars: 4,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
batch, err := provider.Embed(context.Background(), []string{"abcdef", "xy"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "nomic-embed-text", batch.Model)
|
|
||||||
require.Equal(t, 3, batch.Dimensions)
|
|
||||||
require.Equal(t, [][]float32{{1, 2, 3}, {4, 5, 6}}, batch.Vectors)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAICompatibleProviderEmbedsAndUsesAuth(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
assert.Equal(t, "/embeddings", r.URL.Path)
|
|
||||||
assert.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
|
|
||||||
var req openAIEmbeddingRequest
|
|
||||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
|
||||||
assert.Equal(t, "local-model", req.Model)
|
|
||||||
assert.Equal(t, []string{"one", "two"}, req.Input)
|
|
||||||
_, _ = w.Write([]byte(`{
|
|
||||||
"model":"local-model",
|
|
||||||
"data":[
|
|
||||||
{"index":1,"embedding":[3,4]},
|
|
||||||
{"index":0,"embedding":[1,2]}
|
|
||||||
]
|
|
||||||
}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
t.Setenv("DISCRAWL_EMBED_KEY", "secret")
|
|
||||||
|
|
||||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOpenAICompatible,
|
|
||||||
Model: "local-model",
|
|
||||||
BaseURL: server.URL,
|
|
||||||
APIKeyEnv: "DISCRAWL_EMBED_KEY",
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
batch, err := provider.Embed(context.Background(), []string{"one", "two"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "local-model", batch.Model)
|
|
||||||
require.Equal(t, 2, batch.Dimensions)
|
|
||||||
require.Equal(t, [][]float32{{1, 2}, {3, 4}}, batch.Vectors)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProviderFactoryDefaultsAndValidation(t *testing.T) {
|
|
||||||
t.Setenv("OPENAI_API_KEY", "openai-secret")
|
|
||||||
|
|
||||||
openAI, err := resolveProviderConfig(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOpenAI,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
}, true)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, DefaultOpenAIBaseURL, openAI.BaseURL)
|
|
||||||
require.Equal(t, DefaultOpenAIModel, openAI.Model)
|
|
||||||
require.Equal(t, "openai-secret", openAI.APIKey)
|
|
||||||
|
|
||||||
ollama, err := resolveProviderConfig(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOllama,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
}, true)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, DefaultOllamaBaseURL, ollama.BaseURL)
|
|
||||||
require.Equal(t, DefaultLocalEmbeddingModel, ollama.Model)
|
|
||||||
|
|
||||||
llamaCpp, err := resolveProviderConfig(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderLlamaCpp,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
}, true)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, DefaultLlamaCppBaseURL, llamaCpp.BaseURL)
|
|
||||||
|
|
||||||
_, err = resolveProviderConfig(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOpenAICompatible,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
}, true)
|
|
||||||
require.ErrorContains(t, err, "requires base_url")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProviderFactoryRequiresOpenAIAPIKey(t *testing.T) {
|
|
||||||
t.Setenv("OPENAI_API_KEY", "")
|
|
||||||
|
|
||||||
_, err := NewProvider(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOpenAI,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.ErrorContains(t, err, "requires API key env OPENAI_API_KEY")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProviderFactoryReportsUnsupportedProviderBeforeAPIKey(t *testing.T) {
|
|
||||||
t.Setenv("MISSING_EMBED_KEY", "")
|
|
||||||
|
|
||||||
_, err := NewProvider(config.EmbeddingsConfig{
|
|
||||||
Provider: "bogus",
|
|
||||||
APIKeyEnv: "MISSING_EMBED_KEY",
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.ErrorContains(t, err, "unsupported embedding provider \"bogus\"")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckProviderProbesLocalProvider(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
|
||||||
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","embeddings":[[1,2]]}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
result := CheckProvider(context.Background(), config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOllama,
|
|
||||||
Model: "nomic-embed-text",
|
|
||||||
BaseURL: server.URL,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.Equal(t, "ok", result.Status)
|
|
||||||
require.True(t, result.Probed)
|
|
||||||
require.Empty(t, result.Warning)
|
|
||||||
require.Equal(t, server.URL, result.BaseURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Error(w, "not ready", http.StatusServiceUnavailable)
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
result := CheckProvider(context.Background(), config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOllama,
|
|
||||||
Model: "nomic-embed-text",
|
|
||||||
BaseURL: server.URL,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.Equal(t, "warning", result.Status)
|
|
||||||
require.Contains(t, result.Warning, "HTTP 503")
|
|
||||||
require.False(t, result.Probed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProviderExposesRateLimitErrors(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOpenAICompatible,
|
|
||||||
Model: "local-model",
|
|
||||||
BaseURL: server.URL,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = provider.Embed(context.Background(), []string{"one"})
|
|
||||||
require.ErrorContains(t, err, "HTTP 429")
|
|
||||||
require.True(t, IsRateLimitError(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProviderRejectsInvalidResponses(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1]},{"index":1,"embedding":[2,3]}]}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOpenAICompatible,
|
|
||||||
Model: "local-model",
|
|
||||||
BaseURL: server.URL,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = provider.Embed(context.Background(), []string{"one", "two"})
|
|
||||||
require.ErrorContains(t, err, "dimensions mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEmbeddingProvidersHandleEmptyInputsAndIndexErrors(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
settings := providerSettings{
|
|
||||||
Name: ProviderOllama,
|
|
||||||
Model: "model",
|
|
||||||
BaseURL: "http://127.0.0.1:1",
|
|
||||||
MaxInputChars: 10,
|
|
||||||
HTTPClient: http.DefaultClient,
|
|
||||||
}
|
|
||||||
ollama := newOllamaProvider(settings)
|
|
||||||
batch, err := ollama.Embed(context.Background(), nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "model", batch.Model)
|
|
||||||
|
|
||||||
settings.Name = ProviderOpenAICompatible
|
|
||||||
openai := newOpenAICompatibleProvider(settings)
|
|
||||||
batch, err = openai.Embed(context.Background(), nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "model", batch.Model)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
body string
|
|
||||||
inputs []string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "count",
|
|
||||||
body: `{"data":[]}`,
|
|
||||||
inputs: []string{"one"},
|
|
||||||
want: "returned 0 vectors for 1 inputs",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "range",
|
|
||||||
body: `{"data":[{"index":2,"embedding":[1]}]}`,
|
|
||||||
inputs: []string{"one"},
|
|
||||||
want: "index 2 out of range",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "duplicate",
|
|
||||||
body: `{"data":[{"index":0,"embedding":[1]},{"index":0,"embedding":[2]}]}`,
|
|
||||||
inputs: []string{"one", "two"},
|
|
||||||
want: "duplicated index 0",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
_, _ = w.Write([]byte(tc.body))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOpenAICompatible,
|
|
||||||
Model: "model",
|
|
||||||
BaseURL: server.URL,
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = provider.Embed(context.Background(), tc.inputs)
|
|
||||||
require.ErrorContains(t, err, tc.want)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProviderOptionsAndProbeDecisions(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
client := &http.Client{Timeout: time.Second}
|
|
||||||
settings, err := resolveProviderConfig(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOllama,
|
|
||||||
BaseURL: "http://127.0.0.1:11434/",
|
|
||||||
RequestTimeout: "30s",
|
|
||||||
}, true, WithHTTPClient(client), WithRequestTimeout(50*time.Millisecond))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Same(t, client, settings.HTTPClient)
|
|
||||||
require.Equal(t, 50*time.Millisecond, settings.Timeout)
|
|
||||||
require.Equal(t, "http://127.0.0.1:11434", settings.BaseURL)
|
|
||||||
require.True(t, shouldProbe(settings))
|
|
||||||
|
|
||||||
require.True(t, isLoopbackBaseURL("http://localhost:8080/v1"))
|
|
||||||
require.True(t, isLoopbackBaseURL("http://[::1]:8080/v1"))
|
|
||||||
require.False(t, isLoopbackBaseURL("https://api.example.com/v1"))
|
|
||||||
require.False(t, isLoopbackBaseURL("://bad"))
|
|
||||||
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAI}))
|
|
||||||
require.True(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "http://localhost:8080/v1"}))
|
|
||||||
require.False(t, shouldProbe(providerSettings{Name: ProviderOpenAICompatible, BaseURL: "https://api.example.com/v1"}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProviderValidationEdges(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
_, err := resolveProviderConfig(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOllama,
|
|
||||||
RequestTimeout: "not-a-duration",
|
|
||||||
}, true)
|
|
||||||
require.ErrorContains(t, err, "parse embeddings request_timeout")
|
|
||||||
|
|
||||||
_, err = resolveProviderConfig(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOllama,
|
|
||||||
RequestTimeout: "0s",
|
|
||||||
}, true)
|
|
||||||
require.ErrorContains(t, err, "must be positive")
|
|
||||||
|
|
||||||
_, err = resolveProviderConfig(config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOllama,
|
|
||||||
BaseURL: "://bad",
|
|
||||||
}, true)
|
|
||||||
require.ErrorContains(t, err, "invalid embeddings base_url")
|
|
||||||
|
|
||||||
key, err := resolveAPIKey(ProviderOpenAICompatible, "MISSING_EMBED_KEY", false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Empty(t, key)
|
|
||||||
|
|
||||||
_, err = newProvider(providerSettings{Name: "bogus"})
|
|
||||||
require.ErrorContains(t, err, "unsupported embedding provider")
|
|
||||||
|
|
||||||
require.Equal(t, []string{"abc"}, trimInputs([]string{"abc"}, 0))
|
|
||||||
_, err = inferDimensions([][]float32{{}})
|
|
||||||
require.ErrorContains(t, err, "empty vector")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOllamaProviderResponseEdges(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
countServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
|
||||||
_, _ = w.Write([]byte(`{"embeddings":[]}`))
|
|
||||||
}))
|
|
||||||
defer countServer.Close()
|
|
||||||
|
|
||||||
provider := newOllamaProvider(providerSettings{
|
|
||||||
HTTPClient: countServer.Client(),
|
|
||||||
BaseURL: countServer.URL,
|
|
||||||
Model: "fallback-model",
|
|
||||||
MaxInputChars: 10,
|
|
||||||
})
|
|
||||||
_, err := provider.Embed(context.Background(), []string{"one"})
|
|
||||||
require.ErrorContains(t, err, "returned 0 vectors for 1 inputs")
|
|
||||||
|
|
||||||
modelServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
assert.Equal(t, "/api/embed", r.URL.Path)
|
|
||||||
_, _ = w.Write([]byte(`{"embeddings":[[1,2]]}`))
|
|
||||||
}))
|
|
||||||
defer modelServer.Close()
|
|
||||||
|
|
||||||
provider = newOllamaProvider(providerSettings{
|
|
||||||
HTTPClient: modelServer.Client(),
|
|
||||||
BaseURL: modelServer.URL,
|
|
||||||
Model: "fallback-model",
|
|
||||||
MaxInputChars: 10,
|
|
||||||
})
|
|
||||||
batch, err := provider.Embed(context.Background(), []string{"one"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "fallback-model", batch.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckProviderSkipsRemoteCompatibleProbe(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
result := CheckProvider(context.Background(), config.EmbeddingsConfig{
|
|
||||||
Provider: ProviderOpenAICompatible,
|
|
||||||
Model: "remote-model",
|
|
||||||
BaseURL: "https://api.example.com/v1",
|
|
||||||
RequestTimeout: "5s",
|
|
||||||
})
|
|
||||||
require.Equal(t, "ok", result.Status)
|
|
||||||
require.False(t, result.Probed)
|
|
||||||
require.Empty(t, result.Warning)
|
|
||||||
}
|
|
||||||
@ -1,15 +1,14 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/openclaw/discrawl/internal/embed"
|
"github.com/openclaw/crawlkit/embed"
|
||||||
|
"github.com/openclaw/crawlkit/vector"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -476,28 +475,23 @@ func capRunes(value string, maxChars int) string {
|
|||||||
return string(runes[:maxChars])
|
return string(runes[:maxChars])
|
||||||
}
|
}
|
||||||
|
|
||||||
func EncodeEmbeddingVector(vector []float32) ([]byte, error) {
|
func EncodeEmbeddingVector(values []float32) ([]byte, error) {
|
||||||
buf := bytes.NewBuffer(make([]byte, 0, len(vector)*4))
|
blob, err := vector.EncodeFloat32(values)
|
||||||
for _, value := range vector {
|
if err != nil {
|
||||||
if err := binary.Write(buf, binary.LittleEndian, value); err != nil {
|
return nil, fmt.Errorf("encode embedding vector: %w", err)
|
||||||
return nil, fmt.Errorf("encode embedding vector: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return buf.Bytes(), nil
|
return blob, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeEmbeddingVector(blob []byte) ([]float32, error) {
|
func DecodeEmbeddingVector(blob []byte) ([]float32, error) {
|
||||||
if len(blob)%4 != 0 {
|
if len(blob)%4 != 0 {
|
||||||
return nil, fmt.Errorf("embedding blob length %d is not a float32 multiple", len(blob))
|
return nil, fmt.Errorf("embedding blob length %d is not a float32 multiple", len(blob))
|
||||||
}
|
}
|
||||||
out := make([]float32, len(blob)/4)
|
values, err := vector.DecodeFloat32(blob)
|
||||||
reader := bytes.NewReader(blob)
|
if err != nil {
|
||||||
for i := range out {
|
return nil, fmt.Errorf("decode embedding vector: %w", err)
|
||||||
if err := binary.Read(reader, binary.LittleEndian, &out[i]); err != nil {
|
|
||||||
return nil, fmt.Errorf("decode embedding vector: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return out, nil
|
return values, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) EmbeddingBacklog(ctx context.Context) (int, error) {
|
func (s *Store) EmbeddingBacklog(ctx context.Context) (int, error) {
|
||||||
|
|||||||
@ -5,11 +5,12 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/openclaw/crawlkit/vector"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -160,7 +161,7 @@ func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchO
|
|||||||
if len(opts.QueryVector) != opts.Dimensions {
|
if len(opts.QueryVector) != opts.Dimensions {
|
||||||
return nil, fmt.Errorf("semantic query embedding dimensions mismatch: got %d want %d", len(opts.QueryVector), opts.Dimensions)
|
return nil, fmt.Errorf("semantic query embedding dimensions mismatch: got %d want %d", len(opts.QueryVector), opts.Dimensions)
|
||||||
}
|
}
|
||||||
queryNorm := vectorNorm(opts.QueryVector)
|
queryNorm := vector.Norm(opts.QueryVector)
|
||||||
if queryNorm == 0 {
|
if queryNorm == 0 {
|
||||||
return nil, errors.New("semantic query embedding returned a zero vector")
|
return nil, errors.New("semantic query embedding returned a zero vector")
|
||||||
}
|
}
|
||||||
@ -236,15 +237,18 @@ func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchO
|
|||||||
if dimensions != opts.Dimensions {
|
if dimensions != opts.Dimensions {
|
||||||
return nil, fmt.Errorf("stored embedding dimensions mismatch for message %s: got %d want %d", row.MessageID, dimensions, opts.Dimensions)
|
return nil, fmt.Errorf("stored embedding dimensions mismatch for message %s: got %d want %d", row.MessageID, dimensions, opts.Dimensions)
|
||||||
}
|
}
|
||||||
vector, err := DecodeEmbeddingVector(blob)
|
storedVector, err := DecodeEmbeddingVector(blob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode embedding for message %s: %w", row.MessageID, err)
|
return nil, fmt.Errorf("decode embedding for message %s: %w", row.MessageID, err)
|
||||||
}
|
}
|
||||||
if len(vector) != dimensions {
|
if len(storedVector) != dimensions {
|
||||||
return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(vector), dimensions)
|
return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(storedVector), dimensions)
|
||||||
}
|
}
|
||||||
score, err := cosineSimilarity(opts.QueryVector, queryNorm, vector)
|
score, err := vector.CosineSimilarity(opts.QueryVector, queryNorm, storedVector)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "candidate vector is zero") {
|
||||||
|
return nil, fmt.Errorf("score embedding for message %s: stored embedding vector is zero", row.MessageID)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("score embedding for message %s: %w", row.MessageID, err)
|
return nil, fmt.Errorf("score embedding for message %s: %w", row.MessageID, err)
|
||||||
}
|
}
|
||||||
row.CreatedAt = parseTime(created)
|
row.CreatedAt = parseTime(created)
|
||||||
@ -328,26 +332,23 @@ func fuseSearchResults(ftsResults, semanticResults []SearchResult, limit int) []
|
|||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 20
|
limit = 20
|
||||||
}
|
}
|
||||||
entries := make(map[string]*hybridSearchEntry, len(ftsResults)+len(semanticResults))
|
id := func(result SearchResult) string {
|
||||||
addResults := func(results []SearchResult, weight float64, fts bool) {
|
return result.MessageID
|
||||||
for index, result := range results {
|
|
||||||
entry := entries[result.MessageID]
|
|
||||||
if entry == nil {
|
|
||||||
entry = &hybridSearchEntry{result: result}
|
|
||||||
entries[result.MessageID] = entry
|
|
||||||
}
|
|
||||||
if fts {
|
|
||||||
entry.hasFTS = true
|
|
||||||
}
|
|
||||||
entry.score += weight / (rrfK + float64(index+1))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
addResults(ftsResults, ftsRRFWeight, true)
|
ftsIDs := make(map[string]struct{}, len(ftsResults))
|
||||||
addResults(semanticResults, semanticRRFWeight, false)
|
for _, result := range ftsResults {
|
||||||
|
ftsIDs[result.MessageID] = struct{}{}
|
||||||
merged := make([]hybridSearchEntry, 0, len(entries))
|
}
|
||||||
for _, entry := range entries {
|
fused := vector.ReciprocalRankFusion(
|
||||||
merged = append(merged, *entry)
|
[][]SearchResult{ftsResults, semanticResults},
|
||||||
|
[]func(SearchResult) string{id, id},
|
||||||
|
[]float64{ftsRRFWeight, semanticRRFWeight},
|
||||||
|
rrfK,
|
||||||
|
)
|
||||||
|
merged := make([]hybridSearchEntry, 0, len(fused))
|
||||||
|
for _, entry := range fused {
|
||||||
|
_, hasFTS := ftsIDs[entry.Item.MessageID]
|
||||||
|
merged = append(merged, hybridSearchEntry{result: entry.Item, score: entry.Score, hasFTS: hasFTS})
|
||||||
}
|
}
|
||||||
sort.SliceStable(merged, func(i, j int) bool {
|
sort.SliceStable(merged, func(i, j int) bool {
|
||||||
if merged[i].score != merged[j].score {
|
if merged[i].score != merged[j].score {
|
||||||
@ -490,29 +491,6 @@ func (s *Store) searchFallback(ctx context.Context, opts SearchOptions) ([]Searc
|
|||||||
return out, rows.Err()
|
return out, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func cosineSimilarity(query []float32, queryNorm float64, vector []float32) (float64, error) {
|
|
||||||
if len(vector) != len(query) {
|
|
||||||
return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(vector), len(query))
|
|
||||||
}
|
|
||||||
vectorNorm := vectorNorm(vector)
|
|
||||||
if vectorNorm == 0 {
|
|
||||||
return 0, errors.New("stored embedding vector is zero")
|
|
||||||
}
|
|
||||||
var dot float64
|
|
||||||
for i := range query {
|
|
||||||
dot += float64(query[i]) * float64(vector[i])
|
|
||||||
}
|
|
||||||
return dot / (queryNorm * vectorNorm), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func vectorNorm(vector []float32) float64 {
|
|
||||||
var sum float64
|
|
||||||
for _, value := range vector {
|
|
||||||
sum += float64(value) * float64(value)
|
|
||||||
}
|
|
||||||
return math.Sqrt(sum)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) Members(ctx context.Context, guildID, query string, limit int) ([]MemberRow, error) {
|
func (s *Store) Members(ctx context.Context, guildID, query string, limit int) ([]MemberRow, error) {
|
||||||
if strings.TrimSpace(query) != "" {
|
if strings.TrimSpace(query) != "" {
|
||||||
return s.searchMembers(ctx, guildID, query, limit)
|
return s.searchMembers(ctx, guildID, query, limit)
|
||||||
|
|||||||
@ -10,9 +10,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/openclaw/crawlkit/embed"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/openclaw/discrawl/internal/embed"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpsertMessagesBatch(t *testing.T) {
|
func TestUpsertMessagesBatch(t *testing.T) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user