fix(worker): persist portal bridge sockets

This commit is contained in:
Vincent Koc 2026-05-05 00:45:33 -07:00
parent 3395922222
commit 58fe2d85f3
No known key found for this signature in database

View File

@ -99,6 +99,12 @@ interface CodeWebSocketClose {
reason?: string;
}
type BridgeAttachment =
| { kind: "webvnc-agent"; leaseID: string }
| { kind: "webvnc-viewer"; leaseID: string }
| { kind: "code-agent"; leaseID: string }
| { kind: "code-viewer"; leaseID: string; id: string };
export class FleetDurableObject implements DurableObject {
private readonly webVNCAgents = new Map<string, WebSocket>();
private readonly webVNCViewers = new Map<string, WebSocket>();
@ -114,7 +120,9 @@ export class FleetDurableObject implements DurableObject {
private readonly state: DurableObjectState,
private readonly env: Env,
private readonly testProviders: Partial<Record<Provider, CloudProvider>> = {},
) {}
) {
this.restoreBridgeWebSockets();
}
async fetch(request: Request): Promise<Response> {
try {
@ -220,6 +228,130 @@ export class FleetDurableObject implements DurableObject {
}
}
async webSocketMessage(socket: WebSocket, message: string | ArrayBuffer): Promise<void> {
const attachment = bridgeAttachment(socket);
if (!attachment) {
return;
}
await this.handleBridgeMessage(socket, attachment, message);
}
webSocketClose(socket: WebSocket, code: number, reason: string, _wasClean: boolean): void {
this.handleBridgeClose(socket, code, reason);
}
webSocketError(socket: WebSocket, _error: unknown): void {
this.handleBridgeClose(socket, 1011, "bridge socket error");
}
private restoreBridgeWebSockets(): void {
if (typeof this.state.getWebSockets !== "function") {
return;
}
for (const socket of this.state.getWebSockets()) {
const attachment = bridgeAttachment(socket);
if (!attachment || socket.readyState !== WebSocket.OPEN) {
continue;
}
this.trackBridgeSocket(socket, attachment);
}
}
private acceptBridgeWebSocket(socket: WebSocket, attachment: BridgeAttachment): void {
if (typeof this.state.acceptWebSocket === "function") {
this.state.acceptWebSocket(socket, bridgeTags(attachment));
socket.serializeAttachment(attachment);
} else {
socket.accept();
socket.addEventListener("message", (event) => {
void this.handleBridgeMessage(socket, attachment, event.data);
});
socket.addEventListener("close", (event) => {
this.handleBridgeClose(socket, event.code, event.reason);
});
socket.addEventListener("error", () => {
this.handleBridgeClose(socket, 1011, "bridge socket error");
});
}
}
private trackBridgeSocket(socket: WebSocket, attachment: BridgeAttachment): void {
switch (attachment.kind) {
case "webvnc-agent":
this.webVNCAgents.set(attachment.leaseID, socket);
break;
case "webvnc-viewer":
this.webVNCViewers.set(attachment.leaseID, socket);
break;
case "code-agent":
this.codeAgents.set(attachment.leaseID, socket);
break;
case "code-viewer":
this.codeViewers.set(attachment.id, socket);
break;
}
}
private async handleBridgeMessage(
socket: WebSocket,
attachment: BridgeAttachment,
message: string | ArrayBuffer | Blob,
): Promise<void> {
switch (attachment.kind) {
case "webvnc-agent":
await forwardOrBufferWebVNC(
message,
this.webVNCViewers.get(attachment.leaseID),
this.pendingWebVNCToViewer,
attachment.leaseID,
);
break;
case "webvnc-viewer":
await forwardWebVNC(message, this.webVNCAgents.get(attachment.leaseID));
break;
case "code-agent":
await this.handleCodeAgentMessage(attachment.leaseID, message);
break;
case "code-viewer": {
const agent = this.codeAgents.get(attachment.leaseID);
if (agent?.readyState !== WebSocket.OPEN) {
return;
}
const data = await normalizeWebVNCData(message);
const bytes = typeof data === "string" ? textEncoder.encode(data) : new Uint8Array(data);
const outbound: CodeWebSocketData = {
type: "ws_data",
id: attachment.id,
body: bytesToBase64(bytes),
};
agent.send(JSON.stringify(outbound));
break;
}
}
void socket;
}
private handleBridgeClose(socket: WebSocket, code: number, reason: string): void {
const attachment = bridgeAttachment(socket);
if (!attachment) {
return;
}
switch (attachment.kind) {
case "webvnc-agent":
this.clearWebVNCAgent(attachment.leaseID, socket);
break;
case "webvnc-viewer":
this.clearWebVNCViewer(attachment.leaseID, socket);
break;
case "code-agent":
this.clearCodeAgent(attachment.leaseID, socket);
break;
case "code-viewer":
this.clearCodeViewer(attachment.leaseID, attachment.id, socket, code, reason);
break;
}
}
async alarm(): Promise<void> {
await this.expireLeases();
await this.scheduleAlarm();
@ -546,21 +678,11 @@ export class FleetDurableObject implements DurableObject {
const pair = new WebSocketPair();
const client = pair[0];
const agent = pair[1];
agent.accept();
closeSocket(this.webVNCAgents.get(lease.id), 1012, "replaced by a newer WebVNC bridge");
this.pendingWebVNCToViewer.delete(lease.id);
this.webVNCAgents.set(lease.id, agent);
agent.addEventListener("message", (event) => {
void forwardOrBufferWebVNC(
event.data,
this.webVNCViewers.get(lease.id),
this.pendingWebVNCToViewer,
lease.id,
);
});
agent.addEventListener("close", () => this.clearWebVNCAgent(lease.id, agent));
agent.addEventListener("error", () => this.clearWebVNCAgent(lease.id, agent));
this.acceptBridgeWebSocket(agent, { kind: "webvnc-agent", leaseID: lease.id });
return new Response(null, { status: 101, webSocket: client });
}
@ -675,16 +797,11 @@ export class FleetDurableObject implements DurableObject {
const pair = new WebSocketPair();
const client = pair[0];
const agent = pair[1];
agent.accept();
closeSocket(this.codeAgents.get(lease.id), 1012, "replaced by a newer code bridge");
this.clearCodeLease(lease.id);
this.codeAgents.set(lease.id, agent);
agent.addEventListener("message", (event) => {
void this.handleCodeAgentMessage(lease.id, event.data);
});
agent.addEventListener("close", () => this.clearCodeAgent(lease.id, agent));
agent.addEventListener("error", () => this.clearCodeAgent(lease.id, agent));
this.acceptBridgeWebSocket(agent, { kind: "code-agent", leaseID: lease.id });
return new Response(null, { status: 101, webSocket: client });
}
@ -760,9 +877,9 @@ export class FleetDurableObject implements DurableObject {
const pair = new WebSocketPair();
const client = pair[0];
const viewer = pair[1];
viewer.accept();
const id = crypto.randomUUID();
this.codeViewers.set(id, viewer);
this.acceptBridgeWebSocket(viewer, { kind: "code-viewer", leaseID: lease.id, id });
const url = new URL(request.url);
const open: CodeWebSocketOpen = {
type: "ws_open",
@ -771,23 +888,6 @@ export class FleetDurableObject implements DurableObject {
headers: codeForwardHeaders(request.headers),
};
agent.send(JSON.stringify(open));
viewer.addEventListener("message", (event) => {
void (async () => {
const data = await normalizeWebVNCData(event.data);
const bytes = typeof data === "string" ? textEncoder.encode(data) : new Uint8Array(data);
const message: CodeWebSocketData = { type: "ws_data", id, body: bytesToBase64(bytes) };
agent.send(JSON.stringify(message));
})();
});
const close = (code = 1000, reason = "viewer closed") => {
this.codeViewers.delete(id);
const message: CodeWebSocketClose = { type: "ws_close", id, code, reason };
if (agent.readyState === WebSocket.OPEN) {
agent.send(JSON.stringify(message));
}
};
viewer.addEventListener("close", (event) => close(event.code, event.reason));
viewer.addEventListener("error", () => close(1011, "viewer error"));
return new Response(null, { status: 101, webSocket: client });
}
@ -842,6 +942,24 @@ export class FleetDurableObject implements DurableObject {
this.clearCodeLease(leaseID);
}
private clearCodeViewer(
leaseID: string,
id: string,
socket: WebSocket,
code = 1000,
reason = "viewer closed",
): void {
if (this.codeViewers.get(id) !== socket) {
return;
}
this.codeViewers.delete(id);
const agent = this.codeAgents.get(leaseID);
const message: CodeWebSocketClose = { type: "ws_close", id, code, reason };
if (agent?.readyState === WebSocket.OPEN) {
agent.send(JSON.stringify(message));
}
}
private clearCodeLease(_leaseID: string): void {
for (const [id, viewer] of this.codeViewers) {
this.codeViewers.delete(id);
@ -899,15 +1017,10 @@ export class FleetDurableObject implements DurableObject {
const pair = new WebSocketPair();
const client = pair[0];
const viewer = pair[1];
viewer.accept();
this.webVNCViewers.set(lease.id, viewer);
this.acceptBridgeWebSocket(viewer, { kind: "webvnc-viewer", leaseID: lease.id });
flushPendingWebVNC(this.pendingWebVNCToViewer, lease.id, viewer);
viewer.addEventListener("message", (event) => {
void forwardWebVNC(event.data, this.webVNCAgents.get(lease.id));
});
viewer.addEventListener("close", () => this.clearWebVNCViewer(lease.id, viewer));
viewer.addEventListener("error", () => this.clearWebVNCViewer(lease.id, viewer));
return new Response(null, { status: 101, webSocket: client });
}
@ -1745,6 +1858,29 @@ export function codeResponseHeaders(values: Record<string, string>): Headers {
return headers;
}
function bridgeAttachment(socket: WebSocket): BridgeAttachment | undefined {
const attachment = socket.deserializeAttachment?.() as BridgeAttachment | undefined;
if (!attachment || typeof attachment !== "object") {
return undefined;
}
switch (attachment.kind) {
case "webvnc-agent":
case "webvnc-viewer":
case "code-agent":
return typeof attachment.leaseID === "string" ? attachment : undefined;
case "code-viewer":
return typeof attachment.leaseID === "string" && typeof attachment.id === "string"
? attachment
: undefined;
default:
return undefined;
}
}
function bridgeTags(attachment: BridgeAttachment): string[] {
return [`lease:${attachment.leaseID}`, attachment.kind];
}
function bytesToBase64(bytes: Uint8Array): string {
let binary = "";
for (const byte of bytes) {