feat: add exact vector search
This commit is contained in:
parent
1d3eb0dd79
commit
55b7b92dfb
59
internal/vector/exact.go
Normal file
59
internal/vector/exact.go
Normal file
@ -0,0 +1,59 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type Item struct {
|
||||
ThreadID int64
|
||||
Vector []float64
|
||||
}
|
||||
|
||||
type Neighbor struct {
|
||||
ThreadID int64 `json:"thread_id"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
func Query(items []Item, query []float64, limit int, excludeThreadID int64) []Neighbor {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
out := make([]Neighbor, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.ThreadID == excludeThreadID {
|
||||
continue
|
||||
}
|
||||
score := Cosine(query, item.Vector)
|
||||
if score <= 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, Neighbor{ThreadID: item.ThreadID, Score: score})
|
||||
}
|
||||
sort.SliceStable(out, func(i, j int) bool {
|
||||
if out[i].Score == out[j].Score {
|
||||
return out[i].ThreadID < out[j].ThreadID
|
||||
}
|
||||
return out[i].Score > out[j].Score
|
||||
})
|
||||
if len(out) > limit {
|
||||
out = out[:limit]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Cosine(left, right []float64) float64 {
|
||||
if len(left) == 0 || len(left) != len(right) {
|
||||
return 0
|
||||
}
|
||||
var dot, leftMag, rightMag float64
|
||||
for index := range left {
|
||||
dot += left[index] * right[index]
|
||||
leftMag += left[index] * left[index]
|
||||
rightMag += right[index] * right[index]
|
||||
}
|
||||
if leftMag == 0 || rightMag == 0 {
|
||||
return 0
|
||||
}
|
||||
return dot / (math.Sqrt(leftMag) * math.Sqrt(rightMag))
|
||||
}
|
||||
26
internal/vector/exact_test.go
Normal file
26
internal/vector/exact_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
package vector
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCosine(t *testing.T) {
|
||||
if got := Cosine([]float64{1, 0}, []float64{1, 0}); got != 1 {
|
||||
t.Fatalf("cosine same: got %f want 1", got)
|
||||
}
|
||||
if got := Cosine([]float64{1, 0}, []float64{0, 1}); got != 0 {
|
||||
t.Fatalf("cosine orthogonal: got %f want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySortsByScore(t *testing.T) {
|
||||
got := Query([]Item{
|
||||
{ThreadID: 1, Vector: []float64{1, 0}},
|
||||
{ThreadID: 2, Vector: []float64{0.5, 0.5}},
|
||||
{ThreadID: 3, Vector: []float64{0, 1}},
|
||||
}, []float64{1, 0}, 2, 0)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("neighbors: got %d want 2", len(got))
|
||||
}
|
||||
if got[0].ThreadID != 1 || got[1].ThreadID != 2 {
|
||||
t.Fatalf("order: %#v", got)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user