diff --git a/internal/cli/code.go b/internal/cli/code.go index eae0668..af71410 100644 --- a/internal/cli/code.go +++ b/internal/cli/code.go @@ -39,7 +39,10 @@ type codeProxyMessage struct { Frame string `json:"frame,omitempty"` } -const maxCodeBridgeBodyChunkBytes = 63 * 1024 +const ( + maxCodeBridgeBodyChunkBytes = 63 * 1024 + maxPendingCodeBridgeWebSocketMessages = 32 +) func (a App) webCode(ctx context.Context, args []string) error { defaults := defaultConfig() @@ -191,6 +194,7 @@ type codeBridge struct { mu sync.Mutex writeMu sync.Mutex upstream map[string]*websocket.Conn + pending map[string][]codeProxyMessage } func connectCodeBridge(ctx context.Context, coord *CoordinatorClient, leaseID, host, port string) (*codeBridge, error) { @@ -212,6 +216,7 @@ func connectCodeBridge(ctx context.Context, coord *CoordinatorClient, leaseID, h }, debug: os.Getenv("CRABBOX_CODE_DEBUG") == "1", upstream: map[string]*websocket.Conn{}, + pending: map[string][]codeProxyMessage{}, }, nil } @@ -252,6 +257,9 @@ func (b *codeBridge) Close(code websocket.StatusCode, reason string) { _ = conn.Close(websocket.StatusNormalClosure, "bridge stopped") delete(b.upstream, id) } + for id := range b.pending { + delete(b.pending, id) + } } func (b *codeBridge) handleHTTP(ctx context.Context, msg codeProxyMessage) { @@ -323,11 +331,12 @@ func (b *codeBridge) handleHTTP(ctx context.Context, msg codeProxyMessage) { func (b *codeBridge) openUpstreamWebSocket(ctx context.Context, msg codeProxyMessage) { upstream := "ws" + strings.TrimPrefix(b.baseURL, "http") + codeUpstreamPath(msg.Path) b.trace("ws_open id=%s path=%s upstream=%s", msg.ID, msg.Path, upstream) - header := http.Header{} - for key, value := range msg.Headers { - header.Set(key, value) - } - conn, _, err := websocket.Dial(ctx, upstream, &websocket.DialOptions{HTTPHeader: header}) + header, subprotocols := codeWebSocketDialHeaders(b.baseURL, msg.Headers) + b.trace("ws_open_headers id=%s cookie=%t origin=%q subprotocols=%d", msg.ID, header.Get("Cookie") != "", header.Get("Origin"), len(subprotocols)) + conn, _, err := websocket.Dial(ctx, upstream, &websocket.DialOptions{ + HTTPHeader: header, + Subprotocols: subprotocols, + }) if err != nil { b.trace("ws_open_error id=%s error=%v", msg.ID, err) _ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()}) @@ -335,7 +344,18 @@ func (b *codeBridge) openUpstreamWebSocket(ctx context.Context, msg codeProxyMes } b.mu.Lock() b.upstream[msg.ID] = conn + pending := append([]codeProxyMessage(nil), b.pending[msg.ID]...) + delete(b.pending, msg.ID) b.mu.Unlock() + b.trace("ws_open_ok id=%s subprotocols=%d pending=%d", msg.ID, len(subprotocols), len(pending)) + for _, pendingMessage := range pending { + if err := b.writeUpstreamFrame(ctx, conn, pendingMessage); err != nil { + b.trace("ws_pending_write_error id=%s error=%v", msg.ID, err) + b.closeUpstreamWebSocket(msg.ID, websocket.StatusInternalError, err.Error()) + _ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()}) + return + } + } go b.readUpstreamWebSocket(ctx, msg.ID, conn) } @@ -361,16 +381,37 @@ func (b *codeBridge) readUpstreamWebSocket(ctx context.Context, id string, conn } func (b *codeBridge) writeUpstreamWebSocket(ctx context.Context, msg codeProxyMessage) { - data, _ := base64.StdEncoding.DecodeString(msg.Body) b.mu.Lock() conn := b.upstream[msg.ID] - b.mu.Unlock() if conn == nil { - b.trace("ws_downstream_missing id=%s frame=%s bytes=%d", msg.ID, msg.Frame, len(data)) + pending := b.pending[msg.ID] + if len(pending) >= maxPendingCodeBridgeWebSocketMessages { + b.mu.Unlock() + b.trace("ws_downstream_drop id=%s frame=%s pending=%d", msg.ID, msg.Frame, len(pending)) + _ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusPolicyViolation), Reason: "too many pending websocket messages"}) + return + } + b.pending[msg.ID] = append(pending, msg) + b.mu.Unlock() + b.trace("ws_downstream_buffered id=%s frame=%s pending=%d", msg.ID, msg.Frame, len(pending)+1) return } - b.trace("ws_downstream_data id=%s frame=%s bytes=%d", msg.ID, websocketMessageType(msg.Frame), len(data)) - _ = conn.Write(ctx, websocketMessageType(msg.Frame), data) + b.mu.Unlock() + if err := b.writeUpstreamFrame(ctx, conn, msg); err != nil { + b.trace("ws_downstream_write_error id=%s error=%v", msg.ID, err) + b.closeUpstreamWebSocket(msg.ID, websocket.StatusInternalError, err.Error()) + _ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()}) + } +} + +func (b *codeBridge) writeUpstreamFrame(ctx context.Context, conn *websocket.Conn, msg codeProxyMessage) error { + data, err := base64.StdEncoding.DecodeString(msg.Body) + if err != nil { + return err + } + frameType := websocketMessageType(msg.Frame) + b.trace("ws_downstream_data id=%s frame=%s bytes=%d", msg.ID, codeFrameType(frameType), len(data)) + return conn.Write(ctx, frameType, data) } func (b *codeBridge) closeUpstreamWebSocket(id string, code websocket.StatusCode, reason string) { @@ -380,6 +421,7 @@ func (b *codeBridge) closeUpstreamWebSocket(id string, code websocket.StatusCode b.mu.Lock() conn := b.upstream[id] delete(b.upstream, id) + delete(b.pending, id) b.mu.Unlock() if conn != nil { _ = conn.Close(code, reason) @@ -475,6 +517,32 @@ func websocketMessageType(frame string) websocket.MessageType { return websocket.MessageBinary } +func codeWebSocketDialHeaders(baseURL string, values map[string]string) (http.Header, []string) { + headers := http.Header{} + for key, value := range values { + headers.Set(key, value) + } + subprotocols := websocketSubprotocols(headers) + headers.Del("Sec-WebSocket-Protocol") + if headers.Get("Origin") != "" { + headers.Set("Origin", baseURL) + } + return headers, subprotocols +} + +func websocketSubprotocols(headers http.Header) []string { + var out []string + for _, value := range headers.Values("Sec-WebSocket-Protocol") { + for _, part := range strings.Split(value, ",") { + protocol := strings.TrimSpace(part) + if protocol != "" { + out = append(out, protocol) + } + } + } + return out +} + func availableLocalCodePort() string { for port := 8081; port <= 8180; port++ { ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) diff --git a/internal/cli/code_test.go b/internal/cli/code_test.go index 026d85c..e43d932 100644 --- a/internal/cli/code_test.go +++ b/internal/cli/code_test.go @@ -1,6 +1,7 @@ package cli import ( + "net/http" "strings" "testing" @@ -111,3 +112,35 @@ func TestCodeFrameType(t *testing.T) { t.Fatalf("websocketMessageType default=%v", got) } } + +func TestWebSocketSubprotocols(t *testing.T) { + headers := http.Header{} + headers.Add("Sec-WebSocket-Protocol", "vscode-remote, crabbox") + headers.Add("Sec-WebSocket-Protocol", " second-token ") + got := websocketSubprotocols(headers) + want := []string{"vscode-remote", "crabbox", "second-token"} + if strings.Join(got, "|") != strings.Join(want, "|") { + t.Fatalf("websocketSubprotocols=%q want %q", got, want) + } +} + +func TestCodeWebSocketDialHeadersRewritesOrigin(t *testing.T) { + headers, subprotocols := codeWebSocketDialHeaders("http://127.0.0.1:8081", map[string]string{ + "cookie": "vscode-tkn=remote-token", + "origin": "https://crabbox.openclaw.ai", + "sec-websocket-protocol": "proto-a, proto-b", + }) + + if headers.Get("Origin") != "http://127.0.0.1:8081" { + t.Fatalf("origin=%q", headers.Get("Origin")) + } + if headers.Get("Cookie") != "vscode-tkn=remote-token" { + t.Fatalf("cookie=%q", headers.Get("Cookie")) + } + if headers.Get("Sec-WebSocket-Protocol") != "" { + t.Fatalf("raw subprotocol header should be removed: %q", headers.Get("Sec-WebSocket-Protocol")) + } + if strings.Join(subprotocols, "|") != "proto-a|proto-b" { + t.Fatalf("subprotocols=%q", subprotocols) + } +} diff --git a/worker/src/fleet.ts b/worker/src/fleet.ts index 56c30b9..b3447fe 100644 --- a/worker/src/fleet.ts +++ b/worker/src/fleet.ts @@ -1840,7 +1840,7 @@ function codeLeaseError(lease: LeaseRecord): string { return ""; } -function codeForwardHeaders(headers: Headers): Record { +export function codeForwardHeaders(headers: Headers): Record { const out: Record = {}; const allowed = new Set([ "accept", @@ -1856,11 +1856,24 @@ function codeForwardHeaders(headers: Headers): Record { const lower = key.toLowerCase(); if (allowed.has(lower) || lower.startsWith("x-")) { out[lower] = value; + } else if (lower === "cookie") { + const cookie = codeForwardCookie(value); + if (cookie) { + out["cookie"] = cookie; + } } } return out; } +function codeForwardCookie(value: string): string | undefined { + const tokens = value + .split(";") + .map((part) => part.trim()) + .filter((part) => part.startsWith("vscode-tkn=")); + return tokens.length > 0 ? tokens.join("; ") : undefined; +} + const codePortalContentSecurityPolicy = [ "default-src 'self'", "base-uri 'self'", diff --git a/worker/test/fleet.test.ts b/worker/test/fleet.test.ts index 7c70ef5..b1c03be 100644 --- a/worker/test/fleet.test.ts +++ b/worker/test/fleet.test.ts @@ -2,6 +2,7 @@ import { afterEach, describe, expect, it, vi } from "vitest"; import { FleetDurableObject, + codeForwardHeaders, codeResponseHeaders, flushPendingWebVNC, forwardOrBufferWebVNC, @@ -585,6 +586,19 @@ describe("fleet lease identity and idle", () => { expect(headers.get("cache-control")).toBe("no-store, no-transform"); }); + it("forwards only the VS Code token cookie to code-server", () => { + const headers = codeForwardHeaders( + new Headers({ + cookie: "crabbox_session=secret; vscode-tkn=remote-token; other=value", + origin: "https://crabbox.openclaw.ai", + }), + ); + + expect(headers["cookie"]).toBe("vscode-tkn=remote-token"); + expect(headers["cookie"]).not.toContain("crabbox_session"); + expect(headers.origin).toBe("https://crabbox.openclaw.ai"); + }); + it("serves WebVNC pages only for desktop leases and requires an agent upgrade", async () => { const storage = new MemoryStorage(); const fleet = testFleet(storage);