fix(code): complete portal websocket handshake

This commit is contained in:
Vincent Koc 2026-05-05 01:45:43 -07:00
parent d6ac429cf7
commit 51d505aee4
No known key found for this signature in database
4 changed files with 140 additions and 12 deletions

View File

@ -39,7 +39,10 @@ type codeProxyMessage struct {
Frame string `json:"frame,omitempty"`
}
const maxCodeBridgeBodyChunkBytes = 63 * 1024
const (
maxCodeBridgeBodyChunkBytes = 63 * 1024
maxPendingCodeBridgeWebSocketMessages = 32
)
func (a App) webCode(ctx context.Context, args []string) error {
defaults := defaultConfig()
@ -191,6 +194,7 @@ type codeBridge struct {
mu sync.Mutex
writeMu sync.Mutex
upstream map[string]*websocket.Conn
pending map[string][]codeProxyMessage
}
func connectCodeBridge(ctx context.Context, coord *CoordinatorClient, leaseID, host, port string) (*codeBridge, error) {
@ -212,6 +216,7 @@ func connectCodeBridge(ctx context.Context, coord *CoordinatorClient, leaseID, h
},
debug: os.Getenv("CRABBOX_CODE_DEBUG") == "1",
upstream: map[string]*websocket.Conn{},
pending: map[string][]codeProxyMessage{},
}, nil
}
@ -252,6 +257,9 @@ func (b *codeBridge) Close(code websocket.StatusCode, reason string) {
_ = conn.Close(websocket.StatusNormalClosure, "bridge stopped")
delete(b.upstream, id)
}
for id := range b.pending {
delete(b.pending, id)
}
}
func (b *codeBridge) handleHTTP(ctx context.Context, msg codeProxyMessage) {
@ -323,11 +331,12 @@ func (b *codeBridge) handleHTTP(ctx context.Context, msg codeProxyMessage) {
func (b *codeBridge) openUpstreamWebSocket(ctx context.Context, msg codeProxyMessage) {
upstream := "ws" + strings.TrimPrefix(b.baseURL, "http") + codeUpstreamPath(msg.Path)
b.trace("ws_open id=%s path=%s upstream=%s", msg.ID, msg.Path, upstream)
header := http.Header{}
for key, value := range msg.Headers {
header.Set(key, value)
}
conn, _, err := websocket.Dial(ctx, upstream, &websocket.DialOptions{HTTPHeader: header})
header, subprotocols := codeWebSocketDialHeaders(b.baseURL, msg.Headers)
b.trace("ws_open_headers id=%s cookie=%t origin=%q subprotocols=%d", msg.ID, header.Get("Cookie") != "", header.Get("Origin"), len(subprotocols))
conn, _, err := websocket.Dial(ctx, upstream, &websocket.DialOptions{
HTTPHeader: header,
Subprotocols: subprotocols,
})
if err != nil {
b.trace("ws_open_error id=%s error=%v", msg.ID, err)
_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()})
@ -335,7 +344,18 @@ func (b *codeBridge) openUpstreamWebSocket(ctx context.Context, msg codeProxyMes
}
b.mu.Lock()
b.upstream[msg.ID] = conn
pending := append([]codeProxyMessage(nil), b.pending[msg.ID]...)
delete(b.pending, msg.ID)
b.mu.Unlock()
b.trace("ws_open_ok id=%s subprotocols=%d pending=%d", msg.ID, len(subprotocols), len(pending))
for _, pendingMessage := range pending {
if err := b.writeUpstreamFrame(ctx, conn, pendingMessage); err != nil {
b.trace("ws_pending_write_error id=%s error=%v", msg.ID, err)
b.closeUpstreamWebSocket(msg.ID, websocket.StatusInternalError, err.Error())
_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()})
return
}
}
go b.readUpstreamWebSocket(ctx, msg.ID, conn)
}
@ -361,16 +381,37 @@ func (b *codeBridge) readUpstreamWebSocket(ctx context.Context, id string, conn
}
func (b *codeBridge) writeUpstreamWebSocket(ctx context.Context, msg codeProxyMessage) {
data, _ := base64.StdEncoding.DecodeString(msg.Body)
b.mu.Lock()
conn := b.upstream[msg.ID]
b.mu.Unlock()
if conn == nil {
b.trace("ws_downstream_missing id=%s frame=%s bytes=%d", msg.ID, msg.Frame, len(data))
pending := b.pending[msg.ID]
if len(pending) >= maxPendingCodeBridgeWebSocketMessages {
b.mu.Unlock()
b.trace("ws_downstream_drop id=%s frame=%s pending=%d", msg.ID, msg.Frame, len(pending))
_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusPolicyViolation), Reason: "too many pending websocket messages"})
return
}
b.pending[msg.ID] = append(pending, msg)
b.mu.Unlock()
b.trace("ws_downstream_buffered id=%s frame=%s pending=%d", msg.ID, msg.Frame, len(pending)+1)
return
}
b.trace("ws_downstream_data id=%s frame=%s bytes=%d", msg.ID, websocketMessageType(msg.Frame), len(data))
_ = conn.Write(ctx, websocketMessageType(msg.Frame), data)
b.mu.Unlock()
if err := b.writeUpstreamFrame(ctx, conn, msg); err != nil {
b.trace("ws_downstream_write_error id=%s error=%v", msg.ID, err)
b.closeUpstreamWebSocket(msg.ID, websocket.StatusInternalError, err.Error())
_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()})
}
}
func (b *codeBridge) writeUpstreamFrame(ctx context.Context, conn *websocket.Conn, msg codeProxyMessage) error {
data, err := base64.StdEncoding.DecodeString(msg.Body)
if err != nil {
return err
}
frameType := websocketMessageType(msg.Frame)
b.trace("ws_downstream_data id=%s frame=%s bytes=%d", msg.ID, codeFrameType(frameType), len(data))
return conn.Write(ctx, frameType, data)
}
func (b *codeBridge) closeUpstreamWebSocket(id string, code websocket.StatusCode, reason string) {
@ -380,6 +421,7 @@ func (b *codeBridge) closeUpstreamWebSocket(id string, code websocket.StatusCode
b.mu.Lock()
conn := b.upstream[id]
delete(b.upstream, id)
delete(b.pending, id)
b.mu.Unlock()
if conn != nil {
_ = conn.Close(code, reason)
@ -475,6 +517,32 @@ func websocketMessageType(frame string) websocket.MessageType {
return websocket.MessageBinary
}
func codeWebSocketDialHeaders(baseURL string, values map[string]string) (http.Header, []string) {
headers := http.Header{}
for key, value := range values {
headers.Set(key, value)
}
subprotocols := websocketSubprotocols(headers)
headers.Del("Sec-WebSocket-Protocol")
if headers.Get("Origin") != "" {
headers.Set("Origin", baseURL)
}
return headers, subprotocols
}
func websocketSubprotocols(headers http.Header) []string {
var out []string
for _, value := range headers.Values("Sec-WebSocket-Protocol") {
for _, part := range strings.Split(value, ",") {
protocol := strings.TrimSpace(part)
if protocol != "" {
out = append(out, protocol)
}
}
}
return out
}
func availableLocalCodePort() string {
for port := 8081; port <= 8180; port++ {
ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))

View File

@ -1,6 +1,7 @@
package cli
import (
"net/http"
"strings"
"testing"
@ -111,3 +112,35 @@ func TestCodeFrameType(t *testing.T) {
t.Fatalf("websocketMessageType default=%v", got)
}
}
func TestWebSocketSubprotocols(t *testing.T) {
headers := http.Header{}
headers.Add("Sec-WebSocket-Protocol", "vscode-remote, crabbox")
headers.Add("Sec-WebSocket-Protocol", " second-token ")
got := websocketSubprotocols(headers)
want := []string{"vscode-remote", "crabbox", "second-token"}
if strings.Join(got, "|") != strings.Join(want, "|") {
t.Fatalf("websocketSubprotocols=%q want %q", got, want)
}
}
func TestCodeWebSocketDialHeadersRewritesOrigin(t *testing.T) {
headers, subprotocols := codeWebSocketDialHeaders("http://127.0.0.1:8081", map[string]string{
"cookie": "vscode-tkn=remote-token",
"origin": "https://crabbox.openclaw.ai",
"sec-websocket-protocol": "proto-a, proto-b",
})
if headers.Get("Origin") != "http://127.0.0.1:8081" {
t.Fatalf("origin=%q", headers.Get("Origin"))
}
if headers.Get("Cookie") != "vscode-tkn=remote-token" {
t.Fatalf("cookie=%q", headers.Get("Cookie"))
}
if headers.Get("Sec-WebSocket-Protocol") != "" {
t.Fatalf("raw subprotocol header should be removed: %q", headers.Get("Sec-WebSocket-Protocol"))
}
if strings.Join(subprotocols, "|") != "proto-a|proto-b" {
t.Fatalf("subprotocols=%q", subprotocols)
}
}

View File

@ -1840,7 +1840,7 @@ function codeLeaseError(lease: LeaseRecord): string {
return "";
}
function codeForwardHeaders(headers: Headers): Record<string, string> {
export function codeForwardHeaders(headers: Headers): Record<string, string> {
const out: Record<string, string> = {};
const allowed = new Set([
"accept",
@ -1856,11 +1856,24 @@ function codeForwardHeaders(headers: Headers): Record<string, string> {
const lower = key.toLowerCase();
if (allowed.has(lower) || lower.startsWith("x-")) {
out[lower] = value;
} else if (lower === "cookie") {
const cookie = codeForwardCookie(value);
if (cookie) {
out["cookie"] = cookie;
}
}
}
return out;
}
function codeForwardCookie(value: string): string | undefined {
const tokens = value
.split(";")
.map((part) => part.trim())
.filter((part) => part.startsWith("vscode-tkn="));
return tokens.length > 0 ? tokens.join("; ") : undefined;
}
const codePortalContentSecurityPolicy = [
"default-src 'self'",
"base-uri 'self'",

View File

@ -2,6 +2,7 @@ import { afterEach, describe, expect, it, vi } from "vitest";
import {
FleetDurableObject,
codeForwardHeaders,
codeResponseHeaders,
flushPendingWebVNC,
forwardOrBufferWebVNC,
@ -585,6 +586,19 @@ describe("fleet lease identity and idle", () => {
expect(headers.get("cache-control")).toBe("no-store, no-transform");
});
it("forwards only the VS Code token cookie to code-server", () => {
const headers = codeForwardHeaders(
new Headers({
cookie: "crabbox_session=secret; vscode-tkn=remote-token; other=value",
origin: "https://crabbox.openclaw.ai",
}),
);
expect(headers["cookie"]).toBe("vscode-tkn=remote-token");
expect(headers["cookie"]).not.toContain("crabbox_session");
expect(headers.origin).toBe("https://crabbox.openclaw.ai");
});
it("serves WebVNC pages only for desktop leases and requires an agent upgrade", async () => {
const storage = new MemoryStorage();
const fleet = testFleet(storage);