From 58fe2d85f3838eafe2e5e2aa296d81ecd7ec0e72 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Tue, 5 May 2026 00:45:33 -0700 Subject: [PATCH] fix(worker): persist portal bridge sockets --- worker/src/fleet.ts | 220 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 178 insertions(+), 42 deletions(-) diff --git a/worker/src/fleet.ts b/worker/src/fleet.ts index abdb137..a132533 100644 --- a/worker/src/fleet.ts +++ b/worker/src/fleet.ts @@ -99,6 +99,12 @@ interface CodeWebSocketClose { reason?: string; } +type BridgeAttachment = + | { kind: "webvnc-agent"; leaseID: string } + | { kind: "webvnc-viewer"; leaseID: string } + | { kind: "code-agent"; leaseID: string } + | { kind: "code-viewer"; leaseID: string; id: string }; + export class FleetDurableObject implements DurableObject { private readonly webVNCAgents = new Map(); private readonly webVNCViewers = new Map(); @@ -114,7 +120,9 @@ export class FleetDurableObject implements DurableObject { private readonly state: DurableObjectState, private readonly env: Env, private readonly testProviders: Partial> = {}, - ) {} + ) { + this.restoreBridgeWebSockets(); + } async fetch(request: Request): Promise { try { @@ -220,6 +228,130 @@ export class FleetDurableObject implements DurableObject { } } + async webSocketMessage(socket: WebSocket, message: string | ArrayBuffer): Promise { + const attachment = bridgeAttachment(socket); + if (!attachment) { + return; + } + await this.handleBridgeMessage(socket, attachment, message); + } + + webSocketClose(socket: WebSocket, code: number, reason: string, _wasClean: boolean): void { + this.handleBridgeClose(socket, code, reason); + } + + webSocketError(socket: WebSocket, _error: unknown): void { + this.handleBridgeClose(socket, 1011, "bridge socket error"); + } + + private restoreBridgeWebSockets(): void { + if (typeof this.state.getWebSockets !== "function") { + return; + } + for (const socket of this.state.getWebSockets()) { + const attachment = bridgeAttachment(socket); + if (!attachment || socket.readyState !== WebSocket.OPEN) { + continue; + } + this.trackBridgeSocket(socket, attachment); + } + } + + private acceptBridgeWebSocket(socket: WebSocket, attachment: BridgeAttachment): void { + if (typeof this.state.acceptWebSocket === "function") { + this.state.acceptWebSocket(socket, bridgeTags(attachment)); + socket.serializeAttachment(attachment); + } else { + socket.accept(); + socket.addEventListener("message", (event) => { + void this.handleBridgeMessage(socket, attachment, event.data); + }); + socket.addEventListener("close", (event) => { + this.handleBridgeClose(socket, event.code, event.reason); + }); + socket.addEventListener("error", () => { + this.handleBridgeClose(socket, 1011, "bridge socket error"); + }); + } + } + + private trackBridgeSocket(socket: WebSocket, attachment: BridgeAttachment): void { + switch (attachment.kind) { + case "webvnc-agent": + this.webVNCAgents.set(attachment.leaseID, socket); + break; + case "webvnc-viewer": + this.webVNCViewers.set(attachment.leaseID, socket); + break; + case "code-agent": + this.codeAgents.set(attachment.leaseID, socket); + break; + case "code-viewer": + this.codeViewers.set(attachment.id, socket); + break; + } + } + + private async handleBridgeMessage( + socket: WebSocket, + attachment: BridgeAttachment, + message: string | ArrayBuffer | Blob, + ): Promise { + switch (attachment.kind) { + case "webvnc-agent": + await forwardOrBufferWebVNC( + message, + this.webVNCViewers.get(attachment.leaseID), + this.pendingWebVNCToViewer, + attachment.leaseID, + ); + break; + case "webvnc-viewer": + await forwardWebVNC(message, this.webVNCAgents.get(attachment.leaseID)); + break; + case "code-agent": + await this.handleCodeAgentMessage(attachment.leaseID, message); + break; + case "code-viewer": { + const agent = this.codeAgents.get(attachment.leaseID); + if (agent?.readyState !== WebSocket.OPEN) { + return; + } + const data = await normalizeWebVNCData(message); + const bytes = typeof data === "string" ? textEncoder.encode(data) : new Uint8Array(data); + const outbound: CodeWebSocketData = { + type: "ws_data", + id: attachment.id, + body: bytesToBase64(bytes), + }; + agent.send(JSON.stringify(outbound)); + break; + } + } + void socket; + } + + private handleBridgeClose(socket: WebSocket, code: number, reason: string): void { + const attachment = bridgeAttachment(socket); + if (!attachment) { + return; + } + switch (attachment.kind) { + case "webvnc-agent": + this.clearWebVNCAgent(attachment.leaseID, socket); + break; + case "webvnc-viewer": + this.clearWebVNCViewer(attachment.leaseID, socket); + break; + case "code-agent": + this.clearCodeAgent(attachment.leaseID, socket); + break; + case "code-viewer": + this.clearCodeViewer(attachment.leaseID, attachment.id, socket, code, reason); + break; + } + } + async alarm(): Promise { await this.expireLeases(); await this.scheduleAlarm(); @@ -546,21 +678,11 @@ export class FleetDurableObject implements DurableObject { const pair = new WebSocketPair(); const client = pair[0]; const agent = pair[1]; - agent.accept(); closeSocket(this.webVNCAgents.get(lease.id), 1012, "replaced by a newer WebVNC bridge"); this.pendingWebVNCToViewer.delete(lease.id); this.webVNCAgents.set(lease.id, agent); - agent.addEventListener("message", (event) => { - void forwardOrBufferWebVNC( - event.data, - this.webVNCViewers.get(lease.id), - this.pendingWebVNCToViewer, - lease.id, - ); - }); - agent.addEventListener("close", () => this.clearWebVNCAgent(lease.id, agent)); - agent.addEventListener("error", () => this.clearWebVNCAgent(lease.id, agent)); + this.acceptBridgeWebSocket(agent, { kind: "webvnc-agent", leaseID: lease.id }); return new Response(null, { status: 101, webSocket: client }); } @@ -675,16 +797,11 @@ export class FleetDurableObject implements DurableObject { const pair = new WebSocketPair(); const client = pair[0]; const agent = pair[1]; - agent.accept(); closeSocket(this.codeAgents.get(lease.id), 1012, "replaced by a newer code bridge"); this.clearCodeLease(lease.id); this.codeAgents.set(lease.id, agent); - agent.addEventListener("message", (event) => { - void this.handleCodeAgentMessage(lease.id, event.data); - }); - agent.addEventListener("close", () => this.clearCodeAgent(lease.id, agent)); - agent.addEventListener("error", () => this.clearCodeAgent(lease.id, agent)); + this.acceptBridgeWebSocket(agent, { kind: "code-agent", leaseID: lease.id }); return new Response(null, { status: 101, webSocket: client }); } @@ -760,9 +877,9 @@ export class FleetDurableObject implements DurableObject { const pair = new WebSocketPair(); const client = pair[0]; const viewer = pair[1]; - viewer.accept(); const id = crypto.randomUUID(); this.codeViewers.set(id, viewer); + this.acceptBridgeWebSocket(viewer, { kind: "code-viewer", leaseID: lease.id, id }); const url = new URL(request.url); const open: CodeWebSocketOpen = { type: "ws_open", @@ -771,23 +888,6 @@ export class FleetDurableObject implements DurableObject { headers: codeForwardHeaders(request.headers), }; agent.send(JSON.stringify(open)); - viewer.addEventListener("message", (event) => { - void (async () => { - const data = await normalizeWebVNCData(event.data); - const bytes = typeof data === "string" ? textEncoder.encode(data) : new Uint8Array(data); - const message: CodeWebSocketData = { type: "ws_data", id, body: bytesToBase64(bytes) }; - agent.send(JSON.stringify(message)); - })(); - }); - const close = (code = 1000, reason = "viewer closed") => { - this.codeViewers.delete(id); - const message: CodeWebSocketClose = { type: "ws_close", id, code, reason }; - if (agent.readyState === WebSocket.OPEN) { - agent.send(JSON.stringify(message)); - } - }; - viewer.addEventListener("close", (event) => close(event.code, event.reason)); - viewer.addEventListener("error", () => close(1011, "viewer error")); return new Response(null, { status: 101, webSocket: client }); } @@ -842,6 +942,24 @@ export class FleetDurableObject implements DurableObject { this.clearCodeLease(leaseID); } + private clearCodeViewer( + leaseID: string, + id: string, + socket: WebSocket, + code = 1000, + reason = "viewer closed", + ): void { + if (this.codeViewers.get(id) !== socket) { + return; + } + this.codeViewers.delete(id); + const agent = this.codeAgents.get(leaseID); + const message: CodeWebSocketClose = { type: "ws_close", id, code, reason }; + if (agent?.readyState === WebSocket.OPEN) { + agent.send(JSON.stringify(message)); + } + } + private clearCodeLease(_leaseID: string): void { for (const [id, viewer] of this.codeViewers) { this.codeViewers.delete(id); @@ -899,15 +1017,10 @@ export class FleetDurableObject implements DurableObject { const pair = new WebSocketPair(); const client = pair[0]; const viewer = pair[1]; - viewer.accept(); this.webVNCViewers.set(lease.id, viewer); + this.acceptBridgeWebSocket(viewer, { kind: "webvnc-viewer", leaseID: lease.id }); flushPendingWebVNC(this.pendingWebVNCToViewer, lease.id, viewer); - viewer.addEventListener("message", (event) => { - void forwardWebVNC(event.data, this.webVNCAgents.get(lease.id)); - }); - viewer.addEventListener("close", () => this.clearWebVNCViewer(lease.id, viewer)); - viewer.addEventListener("error", () => this.clearWebVNCViewer(lease.id, viewer)); return new Response(null, { status: 101, webSocket: client }); } @@ -1745,6 +1858,29 @@ export function codeResponseHeaders(values: Record): Headers { return headers; } +function bridgeAttachment(socket: WebSocket): BridgeAttachment | undefined { + const attachment = socket.deserializeAttachment?.() as BridgeAttachment | undefined; + if (!attachment || typeof attachment !== "object") { + return undefined; + } + switch (attachment.kind) { + case "webvnc-agent": + case "webvnc-viewer": + case "code-agent": + return typeof attachment.leaseID === "string" ? attachment : undefined; + case "code-viewer": + return typeof attachment.leaseID === "string" && typeof attachment.id === "string" + ? attachment + : undefined; + default: + return undefined; + } +} + +function bridgeTags(attachment: BridgeAttachment): string[] { + return [`lease:${attachment.leaseID}`, attachment.kind]; +} + function bytesToBase64(bytes: Uint8Array): string { let binary = ""; for (const byte of bytes) {