From f60140892c55aaaa282891559e501847e36c6d2a Mon Sep 17 00:00:00 2001 From: Mariano Belinky Date: Sun, 4 Jan 2026 18:09:24 +0100 Subject: [PATCH] Add routing plugin + default policy --- README.md | 89 ++ cmd/clawgo/main.go | 1337 +++++++++++++++++ go.mod | 13 + go.sum | 26 + internal/routing/policy/default/policy.go | 80 + .../routing/policy/default/policy_test.go | 110 ++ internal/routing/routing.go | 61 + 7 files changed, 1716 insertions(+) create mode 100644 README.md create mode 100644 cmd/clawgo/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/routing/policy/default/policy.go create mode 100644 internal/routing/policy/default/policy_test.go create mode 100644 internal/routing/routing.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..7ebcf33 --- /dev/null +++ b/README.md @@ -0,0 +1,89 @@ +# Clawgo (Go Node) + +Minimal headless node client for Raspberry Pi / Linux. Connects to the gateway bridge, handles pairing, streams `voice.transcript` events (stdin/FIFO), subscribes to chat, and can speak responses via local TTS. + +## Build + +```bash +cd clawgo +go build ./cmd/clawgo +``` + +Cross-compile for Pi: + +```bash +GOOS=linux GOARCH=arm64 go build -o /tmp/clawgo-linux-arm64 ./cmd/clawgo +``` + +## Key flags + +| Flag | Description | +| --- | --- | +| `-session-key` | Session for outgoing `voice.transcript` events (default `main`). | +| `-chat-session-key` | Session to subscribe for chat replies (default mirrors `-session-key`). | +| `-chat-subscribe` | Enable chat stream+TTS (default `true`). | +| `-tts-engine` | `system`, `piper`, `elevenlabs`, or `none` (system = `espeak-ng`). | +| `-tts-system-voice` | espeak voice id (default `en-us`). | +| `-tts-system-rate` | Speech rate (wpm). | +| `-mdns-service` | Bonjour service type (default `_clawdis-node._tcp`). | +| `-stdin` | Read transcripts from stdin (pipe/FIFO). | +| `-stdin-file` | Read transcripts from a FIFO/file instead of stdin. | +| `-agent-request` | Send transcripts as `agent.request` (uses agent + deliver). | +| `-deliver` | Deliver agent responses to a provider (requires channel + to). | +| `-deliver-channel` | Delivery provider (telegram/whatsapp/signal/imessage). | +| `-deliver-to` | Delivery destination id. | +| `-quick-actions` | Enable built-in quick actions (default true). | +| `-ping-message` | Message used for telegram ping quick action. | + +## Pair + +```bash +./clawgo pair \ + -bridge 100.88.46.29:18790 \ + -display-name "Razor Pi" +``` + +Approve via `clawdis nodes approve `. + +## Run (FIFO + TTS example) + +```bash +mkfifo /tmp/voice.fifo +# in one terminal +tail -f /tmp/voice.fifo | ./clawgo run \ + -bridge 100.88.46.29:18790 \ + -stdin \ + -chat-subscribe \ + -tts-engine system +# elsewhere +printf hey computer turn on the lightsn > /tmp/voice.fifo +``` + +Each line on the FIFO becomes a `voice.transcript`; chat responses from the `main` session are spoken via `espeak-ng`. + +## systemd example + +See `docs/linux-node.md` for the end-to-end Pi setup. TL;DR: + +1. Install the binary as `/home/pi/clawgo`. +2. Create a wrapper script that keeps a FIFO (`/home/pi/.cache/clawdis/voice.fifo`) open and pipes it into `clawgo run -stdin`. +3. Create `/etc/systemd/system/clawgo.service` pointing to that wrapper. + +## mDNS advertising + +The node advertises `_clawdis-node._tcp` by default. + +```bash +dns-sd -B _clawdis-node._tcp local. +``` + +Override to `_clawdis-bridge._tcp` if you intentionally want it to show up as a gateway beacon: + +```bash +./clawgo run -mdns-service _clawdis-bridge._tcp +``` + +## Notes +- Node state (`nodeId` + token) lives in `~/.clawdis/clawgo.json`. +- Caps default to `voiceWake`; override via `-caps` if you expose more commands. +- Set `bridge.bind: "tailnet"` on the gateway to restrict the bridge to Tailscale. diff --git a/cmd/clawgo/main.go b/cmd/clawgo/main.go new file mode 100644 index 0000000..339cbb2 --- /dev/null +++ b/cmd/clawgo/main.go @@ -0,0 +1,1337 @@ +package main + +import ( + "bufio" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "github.com/clawdbot/clawgo/internal/routing" + _ "github.com/clawdbot/clawgo/internal/routing/policy/default" + + "github.com/grandcat/zeroconf" +) + +type NodeState struct { + NodeID string `json:"nodeId"` + Token string `json:"token,omitempty"` + DisplayName string `json:"displayName,omitempty"` +} + +type BridgeClient struct { + conn net.Conn + mu sync.Mutex + logf func(string, ...any) + done chan struct{} + errs chan error + frames chan map[string]any + eventMu sync.RWMutex + eventHandler func(string, string) +} + +type NodeConfig struct { + BridgeAddr string + StatePath string + NodeID string + DisplayName string + Platform string + Version string + DeviceFamily string + ModelIdentifier string + Caps []string + Commands []string + SessionKey string + PingInterval time.Duration + StdinMode bool + MDNSEnabled bool + MDNSService string + MDNSDomain string + MDNSName string + ChatSessionKey string + ChatSubscribe bool + AgentRequest bool + Deliver bool + DeliverChannel string + DeliverTo string + TTSEngine string + TTSSystemVoice string + TTSSystemRate int + TTSSystemCommand string + StdinPath string + QuickActions bool + QuickPingMessage string + RoutingPlugin string +} + +func main() { + if len(os.Args) < 2 { + usage() + os.Exit(2) + } + + cmd := os.Args[1] + switch cmd { + case "pair": + cfg := parseFlags(cmd, os.Args[2:]) + if err := runPair(cfg); err != nil { + fatal(err) + } + case "run": + cfg := parseFlags(cmd, os.Args[2:]) + if err := runNode(cfg); err != nil { + fatal(err) + } + default: + usage() + os.Exit(2) + } +} + +func usage() { + fmt.Fprintln(os.Stderr, "Usage: clawgo [flags]") + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, "Common flags:") + fmt.Fprintln(os.Stderr, " -bridge Bridge host:port (default 127.0.0.1:18790)") + fmt.Fprintln(os.Stderr, " -state Path to node state JSON (default ~/.clawdis/linux-node.json)") + fmt.Fprintln(os.Stderr, " -node-id Override node id (default derived from hostname)") + fmt.Fprintln(os.Stderr, " -display-name Friendly display name (default hostname)") + fmt.Fprintln(os.Stderr, " -platform Platform label (default linux)") + fmt.Fprintln(os.Stderr, " -version Client version string (default dev)") + fmt.Fprintln(os.Stderr, " -device-family Device family (default raspi or linux)") + fmt.Fprintln(os.Stderr, " -model-identifier Model identifier override") + fmt.Fprintln(os.Stderr, " -caps Comma-separated caps (default voiceWake)") + fmt.Fprintln(os.Stderr, " -commands Comma-separated commands (default none)") + fmt.Fprintln(os.Stderr, " -session-key Session key for voice.transcript (default main)") + fmt.Fprintln(os.Stderr, " -chat-session-key Session key for chat.subscribe (default main)") + fmt.Fprintln(os.Stderr, " -chat-subscribe Subscribe to chat stream for TTS (default true)") + fmt.Fprintln(os.Stderr, " -agent-request Send agent.request instead of voice.transcript") + fmt.Fprintln(os.Stderr, " -deliver Deliver agent response to a channel") + fmt.Fprintln(os.Stderr, " -deliver-channel Channel for delivery (telegram|whatsapp|signal|imessage)") + fmt.Fprintln(os.Stderr, " -deliver-to Destination id (e.g. telegram chat id)") + fmt.Fprintln(os.Stderr, " -tts-engine TTS engine (system, piper, elevenlabs, none)") + fmt.Fprintln(os.Stderr, " -tts-system-voice Voice id for system TTS (default en-us)") + fmt.Fprintln(os.Stderr, " -tts-system-rate Speech rate for system TTS (default 180)") + fmt.Fprintln(os.Stderr, " -tts-system-command Binary for system TTS (default espeak-ng)") + fmt.Fprintln(os.Stderr, " -stdin Read stdin lines and send voice.transcript events") + fmt.Fprintln(os.Stderr, " -stdin-file Read lines from a file/FIFO instead of stdin") + fmt.Fprintln(os.Stderr, " -ping-interval Ping interval (default 30s)") + fmt.Fprintln(os.Stderr, " -quick-actions Enable built-in quick actions (default true)") + fmt.Fprintln(os.Stderr, " -router Routing plugin name (default default)") + fmt.Fprintln(os.Stderr, " -ping-message Message used for telegram ping quick action (default \"Ping.\")") + fmt.Fprintln(os.Stderr, " -mdns Advertise mDNS presence (default true)") + fmt.Fprintln(os.Stderr, " -mdns-service mDNS service type (default _clawdis-node._tcp)") + fmt.Fprintln(os.Stderr, " -mdns-domain mDNS domain (default local.)") + fmt.Fprintln(os.Stderr, " -mdns-name mDNS instance name override") +} + +func parseFlags(cmd string, args []string) NodeConfig { + fs := flag.NewFlagSet(cmd, flag.ExitOnError) + bridge := fs.String("bridge", "127.0.0.1:18790", "bridge host:port") + state := fs.String("state", defaultStatePath(), "path to node state JSON") + nodeID := fs.String("node-id", "", "override node id") + displayName := fs.String("display-name", "", "display name") + platform := fs.String("platform", "linux", "platform label") + version := fs.String("version", "dev", "client version string") + deviceFamily := fs.String("device-family", "", "device family") + modelIdentifier := fs.String("model-identifier", "", "model identifier") + caps := fs.String("caps", "voiceWake", "comma-separated caps") + commands := fs.String("commands", "", "comma-separated commands") + sessionKey := fs.String("session-key", "main", "session key for voice.transcript") + chatSessionKey := fs.String("chat-session-key", "main", "session key for chat.subscribe") + chatSubscribe := fs.Bool("chat-subscribe", true, "subscribe to chat events for TTS") + agentRequest := fs.Bool("agent-request", false, "send agent.request instead of voice.transcript") + deliver := fs.Bool("deliver", false, "deliver agent response to channel") + deliverChannel := fs.String("deliver-channel", "", "deliver channel (telegram|whatsapp|signal|imessage)") + deliverTo := fs.String("deliver-to", "", "deliver destination id") + ttsEngine := fs.String("tts-engine", "system", "TTS engine (system, piper, elevenlabs, none)") + ttsSystemVoice := fs.String("tts-system-voice", "en-us", "voice id for system TTS") + ttsSystemRate := fs.Int("tts-system-rate", 180, "speech rate for system TTS") + ttsSystemCommand := fs.String("tts-system-command", "espeak-ng", "binary for system TTS") + stdinMode := fs.Bool("stdin", false, "read stdin lines for voice.transcript") + stdinFile := fs.String("stdin-file", "", "read input lines from file/FIFO") + pingInterval := fs.Duration("ping-interval", 30*time.Second, "ping interval") + quickActions := fs.Bool("quick-actions", true, "enable built-in quick actions") + quickPingMessage := fs.String("ping-message", "Ping.", "message used for telegram ping quick action") + router := fs.String("router", "default", "routing plugin name") + mdnsEnabled := fs.Bool("mdns", true, "advertise mDNS presence") + mdnsService := fs.String("mdns-service", "_clawdis-node._tcp", "mDNS service type") + mdnsDomain := fs.String("mdns-domain", "local.", "mDNS domain") + mdnsName := fs.String("mdns-name", "", "mDNS instance name override") + _ = fs.Parse(args) + + family := strings.TrimSpace(*deviceFamily) + if family == "" { + if detected := detectDeviceFamily(); detected != "" { + family = detected + } + } + if family == "" { + family = "linux" + } + + cfg := NodeConfig{ + BridgeAddr: strings.TrimSpace(*bridge), + StatePath: strings.TrimSpace(*state), + NodeID: strings.TrimSpace(*nodeID), + DisplayName: strings.TrimSpace(*displayName), + Platform: strings.TrimSpace(*platform), + Version: strings.TrimSpace(*version), + DeviceFamily: family, + ModelIdentifier: strings.TrimSpace(*modelIdentifier), + Caps: splitCSV(*caps), + Commands: splitCSV(*commands), + SessionKey: strings.TrimSpace(*sessionKey), + StdinMode: *stdinMode, + PingInterval: *pingInterval, + MDNSEnabled: *mdnsEnabled, + MDNSService: strings.TrimSpace(*mdnsService), + MDNSDomain: strings.TrimSpace(*mdnsDomain), + MDNSName: strings.TrimSpace(*mdnsName), + ChatSessionKey: strings.TrimSpace(*chatSessionKey), + ChatSubscribe: *chatSubscribe, + AgentRequest: *agentRequest, + Deliver: *deliver, + DeliverChannel: strings.TrimSpace(*deliverChannel), + DeliverTo: strings.TrimSpace(*deliverTo), + TTSEngine: strings.TrimSpace(*ttsEngine), + TTSSystemVoice: strings.TrimSpace(*ttsSystemVoice), + TTSSystemRate: *ttsSystemRate, + TTSSystemCommand: strings.TrimSpace(*ttsSystemCommand), + StdinPath: strings.TrimSpace(*stdinFile), + QuickActions: *quickActions, + QuickPingMessage: strings.TrimSpace(*quickPingMessage), + RoutingPlugin: strings.TrimSpace(*router), + } + return cfg +} + +func runPair(cfg NodeConfig) error { + state, err := loadOrInitState(cfg) + if err != nil { + return err + } + if cfg.NodeID != "" { + state.NodeID = cfg.NodeID + } + if cfg.DisplayName != "" { + state.DisplayName = cfg.DisplayName + } + + client, err := connectBridge(cfg.BridgeAddr) + if err != nil { + return err + } + defer client.Close() + + client.logf("connected to bridge %s", cfg.BridgeAddr) + if err := sendPairRequest(client, cfg, state); err != nil { + return err + } + + token, err := waitForPair(client) + if err != nil { + return err + } + state.Token = token + if err := saveState(cfg.StatePath, state); err != nil { + return err + } + client.logf("paired ok; token saved to %s", cfg.StatePath) + return nil +} + +func runNode(cfg NodeConfig) error { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + logf := func(format string, args ...any) { fmt.Fprintf(os.Stderr, format+"\n", args...) } + + state, err := loadOrInitState(cfg) + if err != nil { + return err + } + if cfg.NodeID != "" { + state.NodeID = cfg.NodeID + } + if cfg.DisplayName != "" { + state.DisplayName = cfg.DisplayName + } + + var stdinCh chan string + if cfg.StdinMode || cfg.StdinPath != "" { + stdinCh = make(chan string, 32) + if cfg.StdinPath != "" { + go readLinesFromPathLoop(ctx, logf, cfg.StdinPath, stdinCh) + } else { + go readLinesFromReader(ctx, logf, os.Stdin, "stdin", stdinCh) + } + } + + var mdnsCleanup func() + mdnsStarted := false + backoff := time.Second + + for { + select { + case <-ctx.Done(): + if mdnsCleanup != nil { + mdnsCleanup() + } + return nil + default: + } + + client, err := connectBridge(cfg.BridgeAddr) + if err != nil { + logf("bridge connect failed: %v", err) + time.Sleep(backoff) + if backoff < 15*time.Second { + backoff *= 2 + if backoff > 15*time.Second { + backoff = 15 * time.Second + } + } + continue + } + backoff = time.Second + client.logf("connected to bridge %s", cfg.BridgeAddr) + + if strings.TrimSpace(state.Token) == "" { + client.logf("no token found; requesting pairing") + if err := sendPairRequest(client, cfg, state); err != nil { + client.Close() + return err + } + token, err := waitForPair(client) + if err != nil { + client.Close() + return err + } + state.Token = token + if err := saveState(cfg.StatePath, state); err != nil { + client.Close() + return err + } + client.logf("paired ok; token saved to %s", cfg.StatePath) + } + + if err := sendHello(client, cfg, state); err != nil { + client.Close() + return err + } + if err := waitForHello(client); err != nil { + client.Close() + return err + } + if cfg.MDNSEnabled && !mdnsStarted { + mdnsCleanup = startMDNS(cfg, state, client.logf) + mdnsStarted = true + } + + var chatHandler *ChatSubscriber + if cfg.ChatSubscribe { + sessionKey := strings.TrimSpace(cfg.ChatSessionKey) + if sessionKey == "" { + sessionKey = strings.TrimSpace(cfg.SessionKey) + } + ttsQueue, err := buildTTSEngine(cfg, client.logf) + if err != nil { + client.logf("tts disabled: %v", err) + } else if ttsQueue != nil && sessionKey != "" { + chatHandler = newChatSubscriber(sessionKey, ttsQueue, client.logf) + if err := subscribeChat(client, sessionKey); err != nil { + client.logf("chat.subscribe failed: %v", err) + } else { + client.logf("chat.subscribe sessionKey=%s", sessionKey) + } + } + } + if chatHandler != nil { + client.setEventHandler(func(evt, payload string) { + if evt == "chat" { + chatHandler.Handle(payload) + } + }) + } + + var router routing.Router + if stdinCh != nil { + routerCfg := routing.Config{ + SessionKey: cfg.SessionKey, + AgentRequest: cfg.AgentRequest, + Deliver: cfg.Deliver, + DeliverChannel: cfg.DeliverChannel, + DeliverTo: cfg.DeliverTo, + QuickActions: cfg.QuickActions, + QuickPingMessage: cfg.QuickPingMessage, + } + r, err := routing.New(cfg.RoutingPlugin, routerCfg, bridgeTransport{client: client}, client.logf) + if err != nil { + client.logf("router init failed: %v", err) + } else { + router = r + } + } + connCtx, connCancel := context.WithCancel(ctx) + if stdinCh != nil { + go forwardStdin(connCtx, client, cfg, stdinCh, router) + } + if cfg.PingInterval > 0 { + go pingLoop(connCtx, client, cfg.PingInterval) + } + + for { + select { + case <-ctx.Done(): + connCancel() + client.Close() + if mdnsCleanup != nil { + mdnsCleanup() + } + return nil + case err := <-client.errs: + if err != nil { + client.logf("bridge error: %v", err) + } + connCancel() + client.Close() + goto reconnect + case frame := <-client.frames: + if frame == nil { + continue + } + if err := handleFrame(client, frame); err != nil { + client.logf("frame error: %v", err) + connCancel() + client.Close() + goto reconnect + } + } + } + + reconnect: + if ctx.Err() != nil { + if mdnsCleanup != nil { + mdnsCleanup() + } + return nil + } + time.Sleep(backoff) + if backoff < 15*time.Second { + backoff *= 2 + if backoff > 15*time.Second { + backoff = 15 * time.Second + } + } + } +} + +func connectBridge(addr string) (*BridgeClient, error) { + if addr == "" { + return nil, errors.New("bridge address required") + } + conn, err := net.DialTimeout("tcp", addr, 5*time.Second) + if err != nil { + return nil, err + } + client := &BridgeClient{ + conn: conn, + logf: func(format string, args ...any) { fmt.Fprintf(os.Stderr, format+"\n", args...) }, + done: make(chan struct{}), + errs: make(chan error, 1), + frames: make(chan map[string]any, 16), + } + go client.readLoop() + return client, nil +} + +func (c *BridgeClient) Close() { + select { + case <-c.done: + default: + close(c.done) + } + _ = c.conn.Close() +} + +func (c *BridgeClient) sendFrame(frame any) error { + payload, err := json.Marshal(frame) + if err != nil { + return err + } + c.mu.Lock() + defer c.mu.Unlock() + _, err = c.conn.Write(append(payload, 10)) + return err +} + +func (c *BridgeClient) setEventHandler(fn func(string, string)) { + c.eventMu.Lock() + c.eventHandler = fn + c.eventMu.Unlock() +} + +func (c *BridgeClient) dispatchEvent(evt, payload string) { + c.eventMu.RLock() + handler := c.eventHandler + c.eventMu.RUnlock() + if handler != nil { + handler(evt, payload) + } +} + +func (c *BridgeClient) readLoop() { + scanner := bufio.NewScanner(c.conn) + scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + var frame map[string]any + if err := json.Unmarshal([]byte(line), &frame); err != nil { + c.logf("invalid frame: %v", err) + continue + } + select { + case c.frames <- frame: + case <-c.done: + return + } + } + if err := scanner.Err(); err != nil && !errors.Is(err, io.EOF) { + c.errs <- err + return + } + c.errs <- io.EOF +} + +func sendPairRequest(c *BridgeClient, cfg NodeConfig, state *NodeState) error { + frame := map[string]any{ + "type": "pair-request", + "nodeId": state.NodeID, + "displayName": state.DisplayName, + "platform": cfg.Platform, + "version": cfg.Version, + "deviceFamily": cfg.DeviceFamily, + "modelIdentifier": cfg.ModelIdentifier, + "caps": cfg.Caps, + "commands": cfg.Commands, + "permissions": map[string]bool{}, + } + return c.sendFrame(frame) +} + +func sendHello(c *BridgeClient, cfg NodeConfig, state *NodeState) error { + frame := map[string]any{ + "type": "hello", + "nodeId": state.NodeID, + "displayName": state.DisplayName, + "token": state.Token, + "platform": cfg.Platform, + "version": cfg.Version, + "deviceFamily": cfg.DeviceFamily, + "modelIdentifier": cfg.ModelIdentifier, + "caps": cfg.Caps, + "commands": cfg.Commands, + "permissions": map[string]bool{}, + } + return c.sendFrame(frame) +} + +func subscribeChat(c *BridgeClient, sessionKey string) error { + if strings.TrimSpace(sessionKey) == "" { + return nil + } + payload, err := json.Marshal(map[string]string{"sessionKey": sessionKey}) + if err != nil { + return err + } + return c.sendFrame(map[string]any{ + "type": "event", + "event": "chat.subscribe", + "payloadJSON": string(payload), + }) +} + +func waitForPair(c *BridgeClient) (string, error) { + deadline := time.After(6 * time.Minute) + for { + select { + case <-deadline: + return "", errors.New("pairing timeout") + case err := <-c.errs: + return "", err + case frame := <-c.frames: + if frame == nil { + continue + } + if err := handleFrame(c, frame); err != nil { + return "", err + } + if frameType(frame) == "pair-ok" { + token := frameString(frame, "token") + if token == "" { + return "", errors.New("pair-ok missing token") + } + return token, nil + } + } + } +} + +func waitForHello(c *BridgeClient) error { + deadline := time.After(30 * time.Second) + for { + select { + case <-deadline: + return errors.New("hello timeout") + case err := <-c.errs: + return err + case frame := <-c.frames: + if frame == nil { + continue + } + if err := handleFrame(c, frame); err != nil { + return err + } + if frameType(frame) == "hello-ok" { + serverName := frameString(frame, "serverName") + canvasHost := frameString(frame, "canvasHostUrl") + if serverName != "" { + c.logf("hello ok (server=%s)", serverName) + } + if canvasHost != "" { + c.logf("canvas host: %s", canvasHost) + } + return nil + } + } + } +} + +func handleFrame(c *BridgeClient, frame map[string]any) error { + switch frameType(frame) { + case "ping": + id := frameString(frame, "id") + if id != "" { + _ = c.sendFrame(map[string]any{"type": "pong", "id": id}) + } + case "error": + code := frameString(frame, "code") + msg := frameString(frame, "message") + if code != "" || msg != "" { + return fmt.Errorf("bridge error: %s %s", code, msg) + } + case "invoke": + id := frameString(frame, "id") + if id != "" { + _ = c.sendFrame(map[string]any{ + "type": "invoke-res", + "id": id, + "ok": false, + "error": map[string]any{ + "code": "UNAVAILABLE", + "message": "headless node has no commands", + }, + }) + } + case "req": + id := frameString(frame, "id") + if id != "" { + _ = c.sendFrame(map[string]any{ + "type": "res", + "id": id, + "ok": false, + "error": map[string]any{ + "code": "UNAVAILABLE", + "message": "headless node has no RPC", + }, + }) + } + case "event": + evt := frameString(frame, "event") + if evt != "" { + payload := frameString(frame, "payloadJSON") + c.logf("event: %s", evt) + c.dispatchEvent(evt, payload) + } + } + return nil +} + +func readLinesFromReader(ctx context.Context, logf func(string, ...any), r io.Reader, label string, out chan<- string) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + } + text := strings.TrimSpace(scanner.Text()) + if text == "" { + continue + } + select { + case out <- text: + case <-ctx.Done(): + return + } + } + if err := scanner.Err(); err != nil { + logf("%s error: %v", label, err) + } +} + +func readLinesFromPathLoop(ctx context.Context, logf func(string, ...any), path string, out chan<- string) { + for { + select { + case <-ctx.Done(): + return + default: + } + f, err := os.Open(path) + if err != nil { + logf("stdin-file open failed: %v", err) + time.Sleep(1 * time.Second) + continue + } + readLinesFromReader(ctx, logf, f, path, out) + _ = f.Close() + time.Sleep(200 * time.Millisecond) + } +} + +func forwardStdin(ctx context.Context, c *BridgeClient, cfg NodeConfig, in <-chan string, router routing.Router) { + for { + select { + case <-ctx.Done(): + return + case text, ok := <-in: + if !ok { + return + } + if router != nil { + handled, err := router.HandleTranscript(ctx, text) + if err != nil { + c.logf("router error: %v", err) + } + if handled { + continue + } + } + if cfg.AgentRequest { + if err := sendAgentRequest(c, cfg, text); err != nil { + c.logf("agent request error: %v", err) + } + continue + } + if err := sendVoiceTranscript(c, cfg, text); err != nil { + c.logf("voice transcript error: %v", err) + } + } + } +} + +type bridgeTransport struct { + client *BridgeClient +} + +func (t bridgeTransport) SendVoiceTranscript(sessionKey, text string) error { + return sendVoiceTranscript(t.client, NodeConfig{SessionKey: sessionKey}, text) +} + +func (t bridgeTransport) SendAgentRequest(sessionKey, text string, deliver bool, channel, to string) error { + cfg := NodeConfig{SessionKey: sessionKey, Deliver: deliver, DeliverChannel: channel, DeliverTo: to} + return sendAgentRequest(t.client, cfg, text) +} + +func (t bridgeTransport) SendProviderMessage(provider, to, message string) error { + return sendProviderMessage(t.client, provider, to, message) +} + +func sendVoiceTranscript(c *BridgeClient, cfg NodeConfig, text string) error { + if strings.TrimSpace(text) == "" { + return nil + } + payload := map[string]any{"text": text} + if cfg.SessionKey != "" { + payload["sessionKey"] = cfg.SessionKey + } + payloadJSON, err := json.Marshal(payload) + if err != nil { + c.logf("voice payload error: %v", err) + return err + } + return c.sendFrame(map[string]any{ + "type": "event", + "event": "voice.transcript", + "payloadJSON": string(payloadJSON), + }) +} + +func sendAgentRequest(c *BridgeClient, cfg NodeConfig, text string) error { + if strings.TrimSpace(text) == "" { + return nil + } + payload := map[string]any{"message": text} + if cfg.SessionKey != "" { + payload["sessionKey"] = cfg.SessionKey + } + if cfg.Deliver && cfg.DeliverChannel != "" { + payload["deliver"] = true + payload["channel"] = cfg.DeliverChannel + if cfg.DeliverTo != "" { + payload["to"] = cfg.DeliverTo + } + } + payloadJSON, err := json.Marshal(payload) + if err != nil { + c.logf("agent payload error: %v", err) + return err + } + return c.sendFrame(map[string]any{ + "type": "event", + "event": "agent.request", + "payloadJSON": string(payloadJSON), + }) +} + +func sendProviderMessage(c *BridgeClient, provider, to, message string) error { + idem := randomID(12) + payload := map[string]any{ + "to": to, + "message": message, + "provider": provider, + "idempotencyKey": idem, + } + paramsJSON, err := json.Marshal(payload) + if err != nil { + return err + } + return c.sendFrame(map[string]any{ + "type": "req", + "id": randomID(8), + "method": "send", + "paramsJSON": string(paramsJSON), + }) +} + +func pingLoop(ctx context.Context, c *BridgeClient, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + id := randomID(8) + _ = c.sendFrame(map[string]any{"type": "ping", "id": id}) + } + } +} + +type TTSEngine interface { + Speak(string) error +} + +type TTSQueue struct { + engine TTSEngine + queue chan string + logf func(string, ...any) +} + +func newTTSQueue(engine TTSEngine, logf func(string, ...any)) *TTSQueue { + if engine == nil { + return nil + } + q := &TTSQueue{ + engine: engine, + queue: make(chan string, 16), + logf: logf, + } + go q.loop() + return q +} + +func (q *TTSQueue) loop() { + for text := range q.queue { + if err := q.engine.Speak(text); err != nil { + q.logf("tts error: %v", err) + } + } +} + +func (q *TTSQueue) Speak(text string) { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return + } + select { + case q.queue <- trimmed: + default: + q.logf("tts queue full; dropping text") + } +} + +type systemTTSEngine struct { + command string + voice string + rate int +} + +func newSystemTTSEngine(cmd, voice string, rate int) (*systemTTSEngine, error) { + if cmd == "" { + cmd = "espeak-ng" + } + resolved, err := exec.LookPath(cmd) + if err != nil { + return nil, err + } + return &systemTTSEngine{command: resolved, voice: voice, rate: rate}, nil +} + +func (s *systemTTSEngine) Speak(text string) error { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return nil + } + args := []string{} + if s.voice != "" { + args = append(args, "-v", s.voice) + } + if s.rate > 0 { + args = append(args, "-s", strconv.Itoa(s.rate)) + } + args = append(args, trimmed) + cmd := exec.Command(s.command, args...) + cmd.Stdout = io.Discard + cmd.Stderr = io.Discard + return cmd.Run() +} + +func buildTTSEngine(cfg NodeConfig, logf func(string, ...any)) (*TTSQueue, error) { + switch strings.ToLower(strings.TrimSpace(cfg.TTSEngine)) { + case "", "none": + return nil, nil + case "system": + engine, err := newSystemTTSEngine(cfg.TTSSystemCommand, cfg.TTSSystemVoice, cfg.TTSSystemRate) + if err != nil { + return nil, err + } + return newTTSQueue(engine, logf), nil + default: + return nil, fmt.Errorf("unsupported tts engine: %s", cfg.TTSEngine) + } +} + +type ChatSubscriber struct { + sessionKey string + tts *TTSQueue + logf func(string, ...any) + mu sync.Mutex + buffers map[string]*strings.Builder +} + +func newChatSubscriber(sessionKey string, tts *TTSQueue, logf func(string, ...any)) *ChatSubscriber { + return &ChatSubscriber{ + sessionKey: strings.TrimSpace(sessionKey), + tts: tts, + logf: logf, + buffers: make(map[string]*strings.Builder), + } +} + +func (c *ChatSubscriber) Handle(payloadJSON string) { + if strings.TrimSpace(payloadJSON) == "" { + return + } + var payload chatPayload + if err := json.Unmarshal([]byte(payloadJSON), &payload); err != nil { + c.logf("chat payload decode failed: %v", err) + return + } + if c.sessionKey != "" && payload.SessionKey != "" && payload.SessionKey != c.sessionKey { + return + } + switch payload.State { + case "delta": + if text := payload.text(); text != "" { + c.append(payload.RunID, text) + } + case "final": + text := payload.text() + if text == "" { + text = c.consume(payload.RunID) + } else { + c.clear(payload.RunID) + } + c.speak(text) + case "error": + c.clear(payload.RunID) + default: + if text := payload.text(); text != "" { + c.speak(text) + } + } +} + +func (c *ChatSubscriber) append(runID, text string) { + clean := strings.TrimSpace(text) + if clean == "" { + return + } + c.mu.Lock() + defer c.mu.Unlock() + buf := c.buffers[runID] + if buf == nil { + buf = &strings.Builder{} + c.buffers[runID] = buf + } + if buf.Len() > 0 { + buf.WriteString(" ") + } + buf.WriteString(clean) +} + +func (c *ChatSubscriber) consume(runID string) string { + c.mu.Lock() + defer c.mu.Unlock() + buf := c.buffers[runID] + if buf == nil { + return "" + } + text := buf.String() + delete(c.buffers, runID) + return strings.TrimSpace(text) +} + +func (c *ChatSubscriber) clear(runID string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.buffers, runID) +} + +func (c *ChatSubscriber) speak(text string) { + text = strings.TrimSpace(text) + if text == "" || c.tts == nil { + return + } + c.tts.Speak(text) +} + +type chatPayload struct { + RunID string `json:"runId"` + SessionKey string `json:"sessionKey"` + State string `json:"state"` + Message *chatMessage `json:"message"` +} + +type chatMessage struct { + Role string `json:"role"` + Content []chatContent `json:"content"` +} + +type chatContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +func (p *chatPayload) text() string { + if p.Message == nil { + return "" + } + var b strings.Builder + for _, part := range p.Message.Content { + if strings.TrimSpace(part.Text) == "" { + continue + } + if b.Len() > 0 { + b.WriteString(" ") + } + b.WriteString(strings.TrimSpace(part.Text)) + } + return strings.TrimSpace(b.String()) +} + +func startMDNS(cfg NodeConfig, state *NodeState, logf func(string, ...any)) func() { + if !cfg.MDNSEnabled { + return nil + } + service := strings.TrimSpace(cfg.MDNSService) + if service == "" { + service = "_clawdis-node._tcp" + } + domain := strings.TrimSpace(cfg.MDNSDomain) + if domain == "" { + domain = "local." + } + name := strings.TrimSpace(cfg.MDNSName) + if name == "" { + name = strings.TrimSpace(state.DisplayName) + } + if name == "" { + name = defaultDisplayName() + } + if !strings.Contains(strings.ToLower(name), "clawdis") { + name = fmt.Sprintf("%s (Clawdis)", name) + } + hostLabel := strings.TrimSpace(defaultDisplayName()) + hostLabel = strings.TrimSuffix(hostLabel, ".local") + if strings.Contains(hostLabel, ".") { + parts := strings.Split(hostLabel, ".") + if len(parts) > 0 { + hostLabel = parts[0] + } + } + if hostLabel == "" { + hostLabel = "clawdis-node" + } + bridgeHost, bridgePort := parseBridgeAddr(cfg.BridgeAddr) + txt := []string{fmt.Sprintf("role=%s", "node"), fmt.Sprintf("displayName=%s", name), fmt.Sprintf("lanHost=%s.local", hostLabel), fmt.Sprintf("nodeId=%s", state.NodeID), "transport=node"} + if bridgeHost != "" { + txt = append(txt, fmt.Sprintf("bridgeHost=%s", bridgeHost)) + } + if bridgePort > 0 { + txt = append(txt, fmt.Sprintf("bridgePort=%d", bridgePort)) + } + if cfg.Platform != "" { + txt = append(txt, fmt.Sprintf("platform=%s", cfg.Platform)) + } + if cfg.Version != "" { + txt = append(txt, fmt.Sprintf("version=%s", cfg.Version)) + } + if cfg.DeviceFamily != "" { + txt = append(txt, fmt.Sprintf("deviceFamily=%s", cfg.DeviceFamily)) + } + if cfg.ModelIdentifier != "" { + txt = append(txt, fmt.Sprintf("modelIdentifier=%s", cfg.ModelIdentifier)) + } + + listener, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + logf("mdns listen failed: %v", err) + return nil + } + port := listener.Addr().(*net.TCPAddr).Port + server, err := zeroconf.Register(name, service, domain, port, txt, nil) + if err != nil { + _ = listener.Close() + logf("mdns register failed: %v", err) + return nil + } + logf("mdns: advertised %s on %s (%s) port=%d", name, service, domain, port) + return func() { + server.Shutdown() + _ = listener.Close() + } +} + +func parseBridgeAddr(addr string) (string, int) { + addr = strings.TrimSpace(addr) + if addr == "" { + return "", 0 + } + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return addr, 0 + } + port, err := strconv.Atoi(portStr) + if err != nil { + return host, 0 + } + return host, port +} + +func defaultStatePath() string { + home, err := os.UserHomeDir() + if err != nil { + return "./linux-node.json" + } + return filepath.Join(home, ".clawdis", "linux-node.json") +} + +func loadOrInitState(cfg NodeConfig) (*NodeState, error) { + path := cfg.StatePath + if path == "" { + path = defaultStatePath() + } + data, err := os.ReadFile(path) + if err == nil { + var st NodeState + if err := json.Unmarshal(data, &st); err != nil { + return nil, err + } + if st.NodeID == "" { + st.NodeID = deriveNodeID() + } + if st.DisplayName == "" { + st.DisplayName = defaultDisplayName() + } + return &st, nil + } + if !errors.Is(err, os.ErrNotExist) { + return nil, err + } + st := &NodeState{ + NodeID: deriveNodeID(), + DisplayName: defaultDisplayName(), + } + if err := saveState(path, st); err != nil { + return nil, err + } + return st, nil +} + +func saveState(path string, st *NodeState) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + payload, err := json.MarshalIndent(st, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, payload, 0o600) +} + +func deriveNodeID() string { + host := defaultDisplayName() + base := sanitizeID(host) + if mid := machineID(); mid != "" { + return fmt.Sprintf("%s-%s", base, mid[:8]) + } + return fmt.Sprintf("%s-%s", base, randomID(6)) +} + +func defaultDisplayName() string { + host, err := os.Hostname() + if err != nil || strings.TrimSpace(host) == "" { + return "linux-node" + } + return host +} + +func machineID() string { + data, err := os.ReadFile("/etc/machine-id") + if err != nil { + return "" + } + id := strings.TrimSpace(string(data)) + id = strings.ReplaceAll(id, "-", "") + if len(id) < 8 { + return "" + } + return id +} + +func sanitizeID(input string) string { + input = strings.ToLower(strings.TrimSpace(input)) + if input == "" { + return "node" + } + var b strings.Builder + for _, r := range input { + if (r >= 97 && r <= 122) || (r >= 48 && r <= 57) { + b.WriteRune(r) + continue + } + if r == 45 || r == 95 { + b.WriteRune(45) + continue + } + b.WriteRune(45) + } + out := strings.Trim(b.String(), "-") + if out == "" { + return "node" + } + return out +} + +func randomID(n int) string { + if n <= 0 { + n = 6 + } + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return hex.EncodeToString(buf) +} + +func splitCSV(value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return nil + } + parts := strings.Split(value, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + item := strings.TrimSpace(part) + if item == "" { + continue + } + out = append(out, item) + } + return out +} + +func frameType(frame map[string]any) string { + if frame == nil { + return "" + } + if v, ok := frame["type"]; ok { + if s, ok := v.(string); ok { + return s + } + return fmt.Sprint(v) + } + return "" +} + +func frameString(frame map[string]any, key string) string { + if frame == nil { + return "" + } + if v, ok := frame[key]; ok { + if s, ok := v.(string); ok { + return s + } + if v != nil { + return fmt.Sprint(v) + } + } + return "" +} + +func fatal(err error) { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) +} + +func detectDeviceFamily() string { + data, err := os.ReadFile("/proc/cpuinfo") + if err != nil { + return "" + } + if strings.Contains(strings.ToLower(string(data)), "raspberry pi") { + return "raspi" + } + return "" +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ab3562e --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module github.com/clawdbot/clawgo + +go 1.22 + +require github.com/grandcat/zeroconf v1.0.0 + +require ( + github.com/cenkalti/backoff v2.2.1+incompatible // indirect + github.com/miekg/dns v1.1.27 // indirect + golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 // indirect + golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa // indirect + golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..82e6ab7 --- /dev/null +++ b/go.sum @@ -0,0 +1,26 @@ +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= +github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE= +github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs= +github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM= +github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe h1:6fAMxZRR6sl1Uq8U61gxU+kPTs2tR8uOySCbBP7BN/M= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/routing/policy/default/policy.go b/internal/routing/policy/default/policy.go new file mode 100644 index 0000000..e1dc58a --- /dev/null +++ b/internal/routing/policy/default/policy.go @@ -0,0 +1,80 @@ +package defaultpolicy + +import ( + "context" + "strings" + "unicode" + + "github.com/clawdbot/clawgo/internal/routing" +) + +type Policy struct { + cfg routing.Config + transport routing.Transport + logf func(string, ...any) +} + +func init() { + routing.Register("default", New) +} + +func New(cfg routing.Config, transport routing.Transport, logf func(string, ...any)) (routing.Router, error) { + if logf == nil { + logf = func(string, ...any) {} + } + return &Policy{cfg: cfg, transport: transport, logf: logf}, nil +} + +func (p *Policy) HandleTranscript(_ context.Context, text string) (bool, error) { + if strings.TrimSpace(text) == "" { + return false, nil + } + if p.cfg.QuickActions { + if handled, err := p.handleQuickActions(text); handled { + return true, err + } + } + if p.cfg.AgentRequest { + return true, p.transport.SendAgentRequest(p.cfg.SessionKey, text, p.cfg.Deliver, p.cfg.DeliverChannel, p.cfg.DeliverTo) + } + return true, p.transport.SendVoiceTranscript(p.cfg.SessionKey, text) +} + +func (p *Policy) handleQuickActions(text string) (bool, error) { + if isTelegramPing(text) { + if p.cfg.DeliverChannel != "telegram" || p.cfg.DeliverTo == "" { + p.logf("quick action skipped: telegram delivery not configured") + return true, nil + } + message := strings.TrimSpace(p.cfg.QuickPingMessage) + if message == "" { + message = "Ping." + } + err := p.transport.SendProviderMessage("telegram", p.cfg.DeliverTo, message) + if err != nil { + p.logf("quick action send failed: %v", err) + } else { + p.logf("quick action: telegram ping sent") + } + return true, err + } + return false, nil +} + +func isTelegramPing(text string) bool { + normalized := normalizeCommand(text) + return strings.Contains(normalized, "telegram") && strings.Contains(normalized, "ping") +} + +func normalizeCommand(text string) string { + lowered := strings.ToLower(text) + var b strings.Builder + for _, r := range lowered { + if unicode.IsLetter(r) || unicode.IsNumber(r) { + b.WriteRune(r) + continue + } + b.WriteRune(32) + } + return strings.Join(strings.Fields(b.String()), " ") +} diff --git a/internal/routing/policy/default/policy_test.go b/internal/routing/policy/default/policy_test.go new file mode 100644 index 0000000..b954488 --- /dev/null +++ b/internal/routing/policy/default/policy_test.go @@ -0,0 +1,110 @@ +package defaultpolicy + +import ( + "context" + "testing" + + "github.com/clawdbot/clawgo/internal/routing" +) + +type call struct { + sessionKey string + text string + deliver bool + channel string + to string + provider string + message string +} + +type fakeTransport struct { + voice []call + agent []call + provider []call +} + +func (f *fakeTransport) SendVoiceTranscript(sessionKey, text string) error { + f.voice = append(f.voice, call{sessionKey: sessionKey, text: text}) + return nil +} + +func (f *fakeTransport) SendAgentRequest(sessionKey, text string, deliver bool, channel, to string) error { + f.agent = append(f.agent, call{sessionKey: sessionKey, text: text, deliver: deliver, channel: channel, to: to}) + return nil +} + +func (f *fakeTransport) SendProviderMessage(provider, to, message string) error { + f.provider = append(f.provider, call{provider: provider, to: to, message: message}) + return nil +} + +func TestDefaultPolicyQuickAction(t *testing.T) { + transport := &fakeTransport{} + cfg := routing.Config{ + QuickActions: true, + DeliverChannel: "telegram", + DeliverTo: "123", + QuickPingMessage: "Ping.", + } + policy, err := New(cfg, transport, func(string, ...any) {}) + if err != nil { + t.Fatalf("new policy: %v", err) + } + handled, err := policy.HandleTranscript(context.Background(), "hey razor ping me on telegram") + if err != nil { + t.Fatalf("handle: %v", err) + } + if !handled { + t.Fatalf("expected handled") + } + if len(transport.provider) != 1 { + t.Fatalf("expected provider message") + } + if len(transport.voice) != 0 || len(transport.agent) != 0 { + t.Fatalf("unexpected fallback sends") + } +} + +func TestDefaultPolicyAgentRequest(t *testing.T) { + transport := &fakeTransport{} + cfg := routing.Config{ + SessionKey: "main", + AgentRequest: true, + Deliver: true, + DeliverChannel: "telegram", + DeliverTo: "123", + } + policy, err := New(cfg, transport, func(string, ...any) {}) + if err != nil { + t.Fatalf("new policy: %v", err) + } + _, err = policy.HandleTranscript(context.Background(), "hello") + if err != nil { + t.Fatalf("handle: %v", err) + } + if len(transport.agent) != 1 { + t.Fatalf("expected agent request") + } + if len(transport.voice) != 0 || len(transport.provider) != 0 { + t.Fatalf("unexpected fallback sends") + } +} + +func TestDefaultPolicyVoiceTranscript(t *testing.T) { + transport := &fakeTransport{} + cfg := routing.Config{SessionKey: "main"} + policy, err := New(cfg, transport, func(string, ...any) {}) + if err != nil { + t.Fatalf("new policy: %v", err) + } + _, err = policy.HandleTranscript(context.Background(), "hello") + if err != nil { + t.Fatalf("handle: %v", err) + } + if len(transport.voice) != 1 { + t.Fatalf("expected voice transcript") + } + if len(transport.agent) != 0 || len(transport.provider) != 0 { + t.Fatalf("unexpected fallback sends") + } +} diff --git a/internal/routing/routing.go b/internal/routing/routing.go new file mode 100644 index 0000000..3f969f8 --- /dev/null +++ b/internal/routing/routing.go @@ -0,0 +1,61 @@ +package routing + +import ( + "context" + "fmt" + "strings" + "sync" +) + +type Transport interface { + SendVoiceTranscript(sessionKey, text string) error + SendAgentRequest(sessionKey, text string, deliver bool, channel, to string) error + SendProviderMessage(provider, to, message string) error +} + +type Router interface { + HandleTranscript(ctx context.Context, text string) (bool, error) +} + +type Config struct { + SessionKey string + AgentRequest bool + Deliver bool + DeliverChannel string + DeliverTo string + QuickActions bool + QuickPingMessage string +} + +type Factory func(cfg Config, transport Transport, logf func(string, ...any)) (Router, error) + +type registry struct { + mu sync.RWMutex + factories map[string]Factory +} + +var globalRegistry = ®istry{factories: map[string]Factory{}} + +func Register(name string, factory Factory) { + name = strings.TrimSpace(name) + if name == "" || factory == nil { + return + } + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + globalRegistry.factories[name] = factory +} + +func New(name string, cfg Config, transport Transport, logf func(string, ...any)) (Router, error) { + name = strings.TrimSpace(name) + if name == "" { + name = "default" + } + globalRegistry.mu.RLock() + factory, ok := globalRegistry.factories[name] + globalRegistry.mu.RUnlock() + if !ok { + return nil, fmt.Errorf("routing plugin not found: %s", name) + } + return factory(cfg, transport, logf) +}