feat: add exact vector search

This commit is contained in:
Vincent Koc 2026-04-26 23:40:37 -07:00
parent 1d3eb0dd79
commit 55b7b92dfb
No known key found for this signature in database
2 changed files with 85 additions and 0 deletions

59
internal/vector/exact.go Normal file
View File

@ -0,0 +1,59 @@
package vector
import (
"math"
"sort"
)
type Item struct {
ThreadID int64
Vector []float64
}
type Neighbor struct {
ThreadID int64 `json:"thread_id"`
Score float64 `json:"score"`
}
func Query(items []Item, query []float64, limit int, excludeThreadID int64) []Neighbor {
if limit <= 0 {
limit = 20
}
out := make([]Neighbor, 0, len(items))
for _, item := range items {
if item.ThreadID == excludeThreadID {
continue
}
score := Cosine(query, item.Vector)
if score <= 0 {
continue
}
out = append(out, Neighbor{ThreadID: item.ThreadID, Score: score})
}
sort.SliceStable(out, func(i, j int) bool {
if out[i].Score == out[j].Score {
return out[i].ThreadID < out[j].ThreadID
}
return out[i].Score > out[j].Score
})
if len(out) > limit {
out = out[:limit]
}
return out
}
func Cosine(left, right []float64) float64 {
if len(left) == 0 || len(left) != len(right) {
return 0
}
var dot, leftMag, rightMag float64
for index := range left {
dot += left[index] * right[index]
leftMag += left[index] * left[index]
rightMag += right[index] * right[index]
}
if leftMag == 0 || rightMag == 0 {
return 0
}
return dot / (math.Sqrt(leftMag) * math.Sqrt(rightMag))
}

View File

@ -0,0 +1,26 @@
package vector
import "testing"
func TestCosine(t *testing.T) {
if got := Cosine([]float64{1, 0}, []float64{1, 0}); got != 1 {
t.Fatalf("cosine same: got %f want 1", got)
}
if got := Cosine([]float64{1, 0}, []float64{0, 1}); got != 0 {
t.Fatalf("cosine orthogonal: got %f want 0", got)
}
}
func TestQuerySortsByScore(t *testing.T) {
got := Query([]Item{
{ThreadID: 1, Vector: []float64{1, 0}},
{ThreadID: 2, Vector: []float64{0.5, 0.5}},
{ThreadID: 3, Vector: []float64{0, 1}},
}, []float64{1, 0}, 2, 0)
if len(got) != 2 {
t.Fatalf("neighbors: got %d want 2", len(got))
}
if got[0].ThreadID != 1 || got[1].ThreadID != 2 {
t.Fatalf("order: %#v", got)
}
}