diff --git a/internal/vector/exact.go b/internal/vector/exact.go new file mode 100644 index 0000000..8ada178 --- /dev/null +++ b/internal/vector/exact.go @@ -0,0 +1,59 @@ +package vector + +import ( + "math" + "sort" +) + +type Item struct { + ThreadID int64 + Vector []float64 +} + +type Neighbor struct { + ThreadID int64 `json:"thread_id"` + Score float64 `json:"score"` +} + +func Query(items []Item, query []float64, limit int, excludeThreadID int64) []Neighbor { + if limit <= 0 { + limit = 20 + } + out := make([]Neighbor, 0, len(items)) + for _, item := range items { + if item.ThreadID == excludeThreadID { + continue + } + score := Cosine(query, item.Vector) + if score <= 0 { + continue + } + out = append(out, Neighbor{ThreadID: item.ThreadID, Score: score}) + } + sort.SliceStable(out, func(i, j int) bool { + if out[i].Score == out[j].Score { + return out[i].ThreadID < out[j].ThreadID + } + return out[i].Score > out[j].Score + }) + if len(out) > limit { + out = out[:limit] + } + return out +} + +func Cosine(left, right []float64) float64 { + if len(left) == 0 || len(left) != len(right) { + return 0 + } + var dot, leftMag, rightMag float64 + for index := range left { + dot += left[index] * right[index] + leftMag += left[index] * left[index] + rightMag += right[index] * right[index] + } + if leftMag == 0 || rightMag == 0 { + return 0 + } + return dot / (math.Sqrt(leftMag) * math.Sqrt(rightMag)) +} diff --git a/internal/vector/exact_test.go b/internal/vector/exact_test.go new file mode 100644 index 0000000..d283350 --- /dev/null +++ b/internal/vector/exact_test.go @@ -0,0 +1,26 @@ +package vector + +import "testing" + +func TestCosine(t *testing.T) { + if got := Cosine([]float64{1, 0}, []float64{1, 0}); got != 1 { + t.Fatalf("cosine same: got %f want 1", got) + } + if got := Cosine([]float64{1, 0}, []float64{0, 1}); got != 0 { + t.Fatalf("cosine orthogonal: got %f want 0", got) + } +} + +func TestQuerySortsByScore(t *testing.T) { + got := Query([]Item{ + {ThreadID: 1, Vector: []float64{1, 0}}, + {ThreadID: 2, Vector: []float64{0.5, 0.5}}, + {ThreadID: 3, Vector: []float64{0, 1}}, + }, []float64{1, 0}, 2, 0) + if len(got) != 2 { + t.Fatalf("neighbors: got %d want 2", len(got)) + } + if got[0].ThreadID != 1 || got[1].ThreadID != 2 { + t.Fatalf("order: %#v", got) + } +}