fix(code): complete portal websocket handshake
This commit is contained in:
parent
d6ac429cf7
commit
51d505aee4
@ -39,7 +39,10 @@ type codeProxyMessage struct {
|
||||
Frame string `json:"frame,omitempty"`
|
||||
}
|
||||
|
||||
const maxCodeBridgeBodyChunkBytes = 63 * 1024
|
||||
const (
|
||||
maxCodeBridgeBodyChunkBytes = 63 * 1024
|
||||
maxPendingCodeBridgeWebSocketMessages = 32
|
||||
)
|
||||
|
||||
func (a App) webCode(ctx context.Context, args []string) error {
|
||||
defaults := defaultConfig()
|
||||
@ -191,6 +194,7 @@ type codeBridge struct {
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
upstream map[string]*websocket.Conn
|
||||
pending map[string][]codeProxyMessage
|
||||
}
|
||||
|
||||
func connectCodeBridge(ctx context.Context, coord *CoordinatorClient, leaseID, host, port string) (*codeBridge, error) {
|
||||
@ -212,6 +216,7 @@ func connectCodeBridge(ctx context.Context, coord *CoordinatorClient, leaseID, h
|
||||
},
|
||||
debug: os.Getenv("CRABBOX_CODE_DEBUG") == "1",
|
||||
upstream: map[string]*websocket.Conn{},
|
||||
pending: map[string][]codeProxyMessage{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -252,6 +257,9 @@ func (b *codeBridge) Close(code websocket.StatusCode, reason string) {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "bridge stopped")
|
||||
delete(b.upstream, id)
|
||||
}
|
||||
for id := range b.pending {
|
||||
delete(b.pending, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *codeBridge) handleHTTP(ctx context.Context, msg codeProxyMessage) {
|
||||
@ -323,11 +331,12 @@ func (b *codeBridge) handleHTTP(ctx context.Context, msg codeProxyMessage) {
|
||||
func (b *codeBridge) openUpstreamWebSocket(ctx context.Context, msg codeProxyMessage) {
|
||||
upstream := "ws" + strings.TrimPrefix(b.baseURL, "http") + codeUpstreamPath(msg.Path)
|
||||
b.trace("ws_open id=%s path=%s upstream=%s", msg.ID, msg.Path, upstream)
|
||||
header := http.Header{}
|
||||
for key, value := range msg.Headers {
|
||||
header.Set(key, value)
|
||||
}
|
||||
conn, _, err := websocket.Dial(ctx, upstream, &websocket.DialOptions{HTTPHeader: header})
|
||||
header, subprotocols := codeWebSocketDialHeaders(b.baseURL, msg.Headers)
|
||||
b.trace("ws_open_headers id=%s cookie=%t origin=%q subprotocols=%d", msg.ID, header.Get("Cookie") != "", header.Get("Origin"), len(subprotocols))
|
||||
conn, _, err := websocket.Dial(ctx, upstream, &websocket.DialOptions{
|
||||
HTTPHeader: header,
|
||||
Subprotocols: subprotocols,
|
||||
})
|
||||
if err != nil {
|
||||
b.trace("ws_open_error id=%s error=%v", msg.ID, err)
|
||||
_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()})
|
||||
@ -335,7 +344,18 @@ func (b *codeBridge) openUpstreamWebSocket(ctx context.Context, msg codeProxyMes
|
||||
}
|
||||
b.mu.Lock()
|
||||
b.upstream[msg.ID] = conn
|
||||
pending := append([]codeProxyMessage(nil), b.pending[msg.ID]...)
|
||||
delete(b.pending, msg.ID)
|
||||
b.mu.Unlock()
|
||||
b.trace("ws_open_ok id=%s subprotocols=%d pending=%d", msg.ID, len(subprotocols), len(pending))
|
||||
for _, pendingMessage := range pending {
|
||||
if err := b.writeUpstreamFrame(ctx, conn, pendingMessage); err != nil {
|
||||
b.trace("ws_pending_write_error id=%s error=%v", msg.ID, err)
|
||||
b.closeUpstreamWebSocket(msg.ID, websocket.StatusInternalError, err.Error())
|
||||
_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
go b.readUpstreamWebSocket(ctx, msg.ID, conn)
|
||||
}
|
||||
|
||||
@ -361,16 +381,37 @@ func (b *codeBridge) readUpstreamWebSocket(ctx context.Context, id string, conn
|
||||
}
|
||||
|
||||
func (b *codeBridge) writeUpstreamWebSocket(ctx context.Context, msg codeProxyMessage) {
|
||||
data, _ := base64.StdEncoding.DecodeString(msg.Body)
|
||||
b.mu.Lock()
|
||||
conn := b.upstream[msg.ID]
|
||||
b.mu.Unlock()
|
||||
if conn == nil {
|
||||
b.trace("ws_downstream_missing id=%s frame=%s bytes=%d", msg.ID, msg.Frame, len(data))
|
||||
pending := b.pending[msg.ID]
|
||||
if len(pending) >= maxPendingCodeBridgeWebSocketMessages {
|
||||
b.mu.Unlock()
|
||||
b.trace("ws_downstream_drop id=%s frame=%s pending=%d", msg.ID, msg.Frame, len(pending))
|
||||
_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusPolicyViolation), Reason: "too many pending websocket messages"})
|
||||
return
|
||||
}
|
||||
b.pending[msg.ID] = append(pending, msg)
|
||||
b.mu.Unlock()
|
||||
b.trace("ws_downstream_buffered id=%s frame=%s pending=%d", msg.ID, msg.Frame, len(pending)+1)
|
||||
return
|
||||
}
|
||||
b.trace("ws_downstream_data id=%s frame=%s bytes=%d", msg.ID, websocketMessageType(msg.Frame), len(data))
|
||||
_ = conn.Write(ctx, websocketMessageType(msg.Frame), data)
|
||||
b.mu.Unlock()
|
||||
if err := b.writeUpstreamFrame(ctx, conn, msg); err != nil {
|
||||
b.trace("ws_downstream_write_error id=%s error=%v", msg.ID, err)
|
||||
b.closeUpstreamWebSocket(msg.ID, websocket.StatusInternalError, err.Error())
|
||||
_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
func (b *codeBridge) writeUpstreamFrame(ctx context.Context, conn *websocket.Conn, msg codeProxyMessage) error {
|
||||
data, err := base64.StdEncoding.DecodeString(msg.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
frameType := websocketMessageType(msg.Frame)
|
||||
b.trace("ws_downstream_data id=%s frame=%s bytes=%d", msg.ID, codeFrameType(frameType), len(data))
|
||||
return conn.Write(ctx, frameType, data)
|
||||
}
|
||||
|
||||
func (b *codeBridge) closeUpstreamWebSocket(id string, code websocket.StatusCode, reason string) {
|
||||
@ -380,6 +421,7 @@ func (b *codeBridge) closeUpstreamWebSocket(id string, code websocket.StatusCode
|
||||
b.mu.Lock()
|
||||
conn := b.upstream[id]
|
||||
delete(b.upstream, id)
|
||||
delete(b.pending, id)
|
||||
b.mu.Unlock()
|
||||
if conn != nil {
|
||||
_ = conn.Close(code, reason)
|
||||
@ -475,6 +517,32 @@ func websocketMessageType(frame string) websocket.MessageType {
|
||||
return websocket.MessageBinary
|
||||
}
|
||||
|
||||
func codeWebSocketDialHeaders(baseURL string, values map[string]string) (http.Header, []string) {
|
||||
headers := http.Header{}
|
||||
for key, value := range values {
|
||||
headers.Set(key, value)
|
||||
}
|
||||
subprotocols := websocketSubprotocols(headers)
|
||||
headers.Del("Sec-WebSocket-Protocol")
|
||||
if headers.Get("Origin") != "" {
|
||||
headers.Set("Origin", baseURL)
|
||||
}
|
||||
return headers, subprotocols
|
||||
}
|
||||
|
||||
func websocketSubprotocols(headers http.Header) []string {
|
||||
var out []string
|
||||
for _, value := range headers.Values("Sec-WebSocket-Protocol") {
|
||||
for _, part := range strings.Split(value, ",") {
|
||||
protocol := strings.TrimSpace(part)
|
||||
if protocol != "" {
|
||||
out = append(out, protocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func availableLocalCodePort() string {
|
||||
for port := 8081; port <= 8180; port++ {
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -111,3 +112,35 @@ func TestCodeFrameType(t *testing.T) {
|
||||
t.Fatalf("websocketMessageType default=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketSubprotocols(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Add("Sec-WebSocket-Protocol", "vscode-remote, crabbox")
|
||||
headers.Add("Sec-WebSocket-Protocol", " second-token ")
|
||||
got := websocketSubprotocols(headers)
|
||||
want := []string{"vscode-remote", "crabbox", "second-token"}
|
||||
if strings.Join(got, "|") != strings.Join(want, "|") {
|
||||
t.Fatalf("websocketSubprotocols=%q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodeWebSocketDialHeadersRewritesOrigin(t *testing.T) {
|
||||
headers, subprotocols := codeWebSocketDialHeaders("http://127.0.0.1:8081", map[string]string{
|
||||
"cookie": "vscode-tkn=remote-token",
|
||||
"origin": "https://crabbox.openclaw.ai",
|
||||
"sec-websocket-protocol": "proto-a, proto-b",
|
||||
})
|
||||
|
||||
if headers.Get("Origin") != "http://127.0.0.1:8081" {
|
||||
t.Fatalf("origin=%q", headers.Get("Origin"))
|
||||
}
|
||||
if headers.Get("Cookie") != "vscode-tkn=remote-token" {
|
||||
t.Fatalf("cookie=%q", headers.Get("Cookie"))
|
||||
}
|
||||
if headers.Get("Sec-WebSocket-Protocol") != "" {
|
||||
t.Fatalf("raw subprotocol header should be removed: %q", headers.Get("Sec-WebSocket-Protocol"))
|
||||
}
|
||||
if strings.Join(subprotocols, "|") != "proto-a|proto-b" {
|
||||
t.Fatalf("subprotocols=%q", subprotocols)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1840,7 +1840,7 @@ function codeLeaseError(lease: LeaseRecord): string {
|
||||
return "";
|
||||
}
|
||||
|
||||
function codeForwardHeaders(headers: Headers): Record<string, string> {
|
||||
export function codeForwardHeaders(headers: Headers): Record<string, string> {
|
||||
const out: Record<string, string> = {};
|
||||
const allowed = new Set([
|
||||
"accept",
|
||||
@ -1856,11 +1856,24 @@ function codeForwardHeaders(headers: Headers): Record<string, string> {
|
||||
const lower = key.toLowerCase();
|
||||
if (allowed.has(lower) || lower.startsWith("x-")) {
|
||||
out[lower] = value;
|
||||
} else if (lower === "cookie") {
|
||||
const cookie = codeForwardCookie(value);
|
||||
if (cookie) {
|
||||
out["cookie"] = cookie;
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function codeForwardCookie(value: string): string | undefined {
|
||||
const tokens = value
|
||||
.split(";")
|
||||
.map((part) => part.trim())
|
||||
.filter((part) => part.startsWith("vscode-tkn="));
|
||||
return tokens.length > 0 ? tokens.join("; ") : undefined;
|
||||
}
|
||||
|
||||
const codePortalContentSecurityPolicy = [
|
||||
"default-src 'self'",
|
||||
"base-uri 'self'",
|
||||
|
||||
@ -2,6 +2,7 @@ import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
import {
|
||||
FleetDurableObject,
|
||||
codeForwardHeaders,
|
||||
codeResponseHeaders,
|
||||
flushPendingWebVNC,
|
||||
forwardOrBufferWebVNC,
|
||||
@ -585,6 +586,19 @@ describe("fleet lease identity and idle", () => {
|
||||
expect(headers.get("cache-control")).toBe("no-store, no-transform");
|
||||
});
|
||||
|
||||
it("forwards only the VS Code token cookie to code-server", () => {
|
||||
const headers = codeForwardHeaders(
|
||||
new Headers({
|
||||
cookie: "crabbox_session=secret; vscode-tkn=remote-token; other=value",
|
||||
origin: "https://crabbox.openclaw.ai",
|
||||
}),
|
||||
);
|
||||
|
||||
expect(headers["cookie"]).toBe("vscode-tkn=remote-token");
|
||||
expect(headers["cookie"]).not.toContain("crabbox_session");
|
||||
expect(headers.origin).toBe("https://crabbox.openclaw.ai");
|
||||
});
|
||||
|
||||
it("serves WebVNC pages only for desktop leases and requires an agent upgrade", async () => {
|
||||
const storage = new MemoryStorage();
|
||||
const fleet = testFleet(storage);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user