crabbox/internal/cli/code.go
2026-05-05 02:34:35 -07:00

468 lines
13 KiB
Go

package cli
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"nhooyr.io/websocket"
)
// coordinatorCodeTicket is the coordinator's JSON response to CreateCodeTicket.
// Only Ticket is consumed in this file: connectCodeBridge passes it as the
// `ticket` query parameter of the /code/agent websocket URL.
type coordinatorCodeTicket struct {
// Ticket is the opaque credential presented when dialing the bridge websocket.
Ticket string `json:"ticket"`
// LeaseID echoes the lease the ticket was issued for.
LeaseID string `json:"leaseID"`
// ExpiresAt is kept as the raw string sent by the coordinator.
// NOTE(review): presumably an RFC 3339 timestamp — confirm before parsing it.
ExpiresAt string `json:"expiresAt"`
}
// codeProxyMessage is one JSON text frame exchanged over the code bridge
// websocket. Type values handled or produced in this file are:
// "http", "http_start", "http_body", "http_end" (proxied HTTP exchanges) and
// "ws_open", "ws_data", "ws_close" (proxied websocket connections).
// ID correlates all frames belonging to the same request or connection.
type codeProxyMessage struct {
Type string `json:"type"`
ID string `json:"id"`
// Method and Path describe an incoming "http" request or "ws_open" target;
// Path is rewritten by codeUpstreamPath before hitting code-server.
Method string `json:"method,omitempty"`
Path string `json:"path,omitempty"`
// Headers carries a single value per header name (multi-valued headers
// are reduced to their first value when responses are relayed).
Headers map[string]string `json:"headers,omitempty"`
// Status is the HTTP status of a relayed response (502 on bridge errors).
Status int `json:"status,omitempty"`
// Body is base64 (standard encoding) of the request, response chunk, or
// websocket payload bytes.
Body string `json:"body,omitempty"`
// Error carries a human-readable failure description for "http" replies.
Error string `json:"error,omitempty"`
// Code and Reason carry the websocket close status for "ws_close".
Code int `json:"code,omitempty"`
Reason string `json:"reason,omitempty"`
}
// maxCodeBridgeBodyChunkBytes caps the raw (pre-base64) response-body bytes
// carried by a single bridge message; handleHTTP splits larger bodies into an
// "http_start" / "http_body"... / "http_end" sequence.
// NOTE(review): presumably sized so the base64-encoded JSON frame stays under
// a coordinator-side message size limit — confirm against the coordinator.
const maxCodeBridgeBodyChunkBytes = 63 * 1024
// webCode implements `crabbox code`. It resolves the requested lease, claims it
// for the current repo, makes sure code-server is running in the lease's remote
// workdir, opens a local tunnel to the remote code-server port, and then keeps
// a bridge websocket to the coordinator open so the portal Code page can reach
// that tunnel. Retryable bridge disconnects are reconnected after a short
// pause; the command returns when ctx is cancelled or on any other error.
func (a App) webCode(ctx context.Context, args []string) error {
	defaults := defaultConfig()
	fs := newFlagSet("code", a.Stderr)
	provider := fs.String("provider", defaults.Provider, "provider: hetzner or aws")
	id := fs.String("id", "", "lease id or slug")
	reclaim := fs.Bool("reclaim", false, "claim this lease for the current repo")
	localPort := fs.String("local-port", "", "local code-server tunnel port")
	openPortal := fs.Bool("open", false, "open the web portal code page")
	networkFlags := registerNetworkModeFlag(fs, defaults)
	targetFlags := registerTargetFlags(fs, defaults)
	if err := parseFlags(fs, args); err != nil {
		return err
	}
	// The lease may also be given as the first positional argument.
	if *id == "" && fs.NArg() > 0 {
		*id = fs.Arg(0)
	}
	if *id == "" {
		return exit(2, "usage: crabbox code --id <lease-id-or-slug>")
	}
	cfg, err := loadConfig()
	if err != nil {
		return err
	}
	cfg.Provider = *provider
	cfg.Code = true
	if err := applyNetworkModeFlagOverride(&cfg, fs, networkFlags); err != nil {
		return err
	}
	if err := applyTargetFlagOverrides(&cfg, fs, targetFlags); err != nil {
		return err
	}
	if err := validateRequestedCapabilities(cfg); err != nil {
		return err
	}
	if isBlacksmithProvider(cfg.Provider) || isStaticProvider(cfg.Provider) {
		return exit(2, "code currently supports coordinator-backed hetzner/aws Linux leases")
	}
	// The portal bridge is brokered by the coordinator, so an authenticated
	// coordinator client is mandatory.
	coord, useCoordinator, err := newTargetCoordinatorClient(cfg)
	if err != nil {
		return err
	}
	if !useCoordinator || coord == nil || coord.Token == "" {
		return exit(2, "code requires a configured coordinator login; run crabbox login first")
	}
	server, target, leaseID, err := a.resolveLeaseTarget(ctx, cfg, *id)
	if err != nil {
		return err
	}
	resolved, err := resolveNetworkTarget(ctx, cfg, server, target)
	if err != nil {
		return err
	}
	target = resolved.Target
	if err := enforceManagedLeaseCapabilities(cfg, server, leaseID); err != nil {
		return err
	}
	repo, err := findRepo()
	if err != nil {
		return err
	}
	if err := claimLeaseForRepoConfig(leaseID, serverSlug(server), cfg, repo.Root, cfg.IdleTimeout, *reclaim); err != nil {
		return err
	}
	a.touchActiveLeaseBestEffort(ctx, cfg, server, leaseID)
	workdir := remoteJoin(cfg, leaseID, repo.Name)
	if err := ensureRemoteCodeServer(ctx, target, workdir); err != nil {
		return err
	}
	if *localPort == "" {
		*localPort = availableLocalCodePort()
	}
	tunnel, err := startVNCForegroundTunnel(ctx, target, *localPort, "127.0.0.1", managedCodePort)
	if err != nil {
		return err
	}
	defer stopProcess(tunnel)
	portal := webCodePortalURL(coord.BaseURL, leaseID)
	opened := false
	for {
		bridge, err := connectCodeBridge(ctx, coord, leaseID, "127.0.0.1", *localPort)
		if err != nil {
			return err
		}
		fmt.Fprintln(a.Stdout, "bridge: connected; keep this process running while using Code")
		fmt.Fprintf(a.Stdout, "code: %s\n", portal)
		// Open the browser only once, not on every reconnect.
		if *openPortal && !opened {
			if err := openLocalURL(portal); err != nil {
				bridge.Close(websocket.StatusNormalClosure, "bridge stopped")
				return err
			}
			opened = true
			fmt.Fprintf(a.Stdout, "opened: %s\n", portal)
		}
		err = bridge.Serve(ctx)
		if ctx.Err() != nil {
			return context.Cause(ctx)
		}
		if !isRetryableCodeBridgeError(err) {
			return err
		}
		fmt.Fprintln(a.Stdout, "bridge: disconnected; reconnecting")
		// Back off briefly, but stay responsive to cancellation (e.g. Ctrl-C);
		// a plain time.Sleep would ignore ctx.
		select {
		case <-ctx.Done():
			return context.Cause(ctx)
		case <-time.After(300 * time.Millisecond):
		}
	}
}
// ensureRemoteCodeServer makes sure code-server answers on the lease's managed
// port. If the readiness probe already succeeds it returns immediately;
// otherwise it launches code-server in workdir and polls the probe every 500ms
// for up to 20s. It returns ctx's cause if ctx is cancelled while waiting, and
// an exit(5) error if the launch fails or the deadline passes.
func ensureRemoteCodeServer(ctx context.Context, target SSHTarget, workdir string) error {
	if err := runSSHQuiet(ctx, target, codeServerReadyCommand()); err == nil {
		return nil
	}
	if err := runSSHQuiet(ctx, target, startCodeServerCommand(workdir)); err != nil {
		return exit(5, "start code-server: %v", err)
	}
	deadline := time.Now().Add(20 * time.Second)
	for time.Now().Before(deadline) {
		if ctx.Err() != nil {
			return context.Cause(ctx)
		}
		if err := runSSHQuiet(ctx, target, codeServerReadyCommand()); err == nil {
			return nil
		}
		// Wait between probes without ignoring cancellation.
		select {
		case <-ctx.Done():
			return context.Cause(ctx)
		case <-time.After(500 * time.Millisecond):
		}
	}
	return exit(5, "timed out waiting for code-server on 127.0.0.1:%s", managedCodePort)
}
// codeServerReadyCommand builds the remote shell probe used to detect a live
// code-server: it curls /healthz on the managed loopback port and, if that
// fails, falls back to fetching the root page. Output is discarded; only the
// exit status matters.
func codeServerReadyCommand() string {
	origin := "http://127.0.0.1:" + managedCodePort
	curlQuiet := func(path string) string {
		return "curl -fsS " + origin + path + " >/dev/null"
	}
	return curlQuiet("/healthz") + " || " + curlQuiet("/")
}
func startCodeServerCommand(workdir string) string {
pidfile := "/tmp/crabbox-code-server.pid"
return strings.Join([]string{
"mkdir -p " + shellQuote(workdir),
"pidfile=" + shellQuote(pidfile) + "; if [ -s \"$pidfile\" ]; then oldpid=$(cat \"$pidfile\" 2>/dev/null || true); if [ -n \"$oldpid\" ] && kill -0 \"$oldpid\" 2>/dev/null; then kill \"$oldpid\" 2>/dev/null || true; fi; fi",
"(nohup env VSCODE_PROXY_URI='./proxy/{{port}}' " + codeServerBinary +
" --auth none --bind-addr 127.0.0.1:" + managedCodePort +
" --disable-telemetry --disable-update-check " + shellQuote(workdir) +
" >/tmp/crabbox-code-server.log 2>&1 & echo $! >" + shellQuote(pidfile) + ")",
}, " && ")
}
// codeBridge relays frames between the coordinator's code bridge websocket and
// the local code-server tunnel. It holds mutexes, so it is used via *codeBridge.
type codeBridge struct {
// ws is the bridge websocket to the coordinator; writes in this file go
// through writeJSON so they are serialized by writeMu.
ws *websocket.Conn
// baseURL is the local tunnel origin, built as "http://" + host + ":" + port.
baseURL string
// client issues the proxied HTTP requests against baseURL.
client *http.Client
// mu guards upstream.
mu sync.Mutex
// writeMu serializes writes to ws, which happen from several goroutines.
writeMu sync.Mutex
// upstream maps a proxied connection ID (codeProxyMessage.ID) to its
// websocket on the local code-server.
upstream map[string]*websocket.Conn
}
// connectCodeBridge obtains a code ticket for leaseID from the coordinator and
// dials the coordinator's /code/agent websocket with it. The returned bridge
// proxies requests to the code-server tunnel listening on host:port.
func connectCodeBridge(ctx context.Context, coord *CoordinatorClient, leaseID, host, port string) (*codeBridge, error) {
	// nhooyr.io/websocket limits incoming messages to 32 KiB by default. Bridge
	// frames carry base64-encoded request bodies and websocket payloads inside
	// JSON, which routinely exceed that and would tear the bridge down.
	const bridgeReadLimit = 64 << 20

	ticket, err := coord.CreateCodeTicket(ctx, leaseID)
	if err != nil {
		return nil, err
	}
	ws, _, err := websocket.Dial(ctx, webCodeAgentURL(coord.BaseURL, leaseID, ticket.Ticket), &websocket.DialOptions{
		HTTPHeader: coord.webVNCAccessHeaders(),
	})
	if err != nil {
		return nil, err
	}
	ws.SetReadLimit(bridgeReadLimit)
	return &codeBridge{
		ws: ws,
		// JoinHostPort brackets IPv6 literals; identical output for IPv4/hostnames.
		baseURL: "http://" + net.JoinHostPort(host, port),
		// NOTE(review): the default http.Client follows redirects, so upstream
		// 3xx responses are resolved here instead of being relayed to the
		// browser — confirm that is intended.
		client: &http.Client{
			Timeout: 30 * time.Second,
		},
		upstream: map[string]*websocket.Conn{},
	}, nil
}
// Serve reads frames from the coordinator until the bridge websocket read
// fails or ctx is cancelled, and returns that read error. Frames are
// dispatched by Type:
//   - "http": replayed against code-server in its own goroutine
//   - "ws_open": dials an upstream websocket in its own goroutine
//   - "ws_data": forwarded synchronously, preserving per-connection order
//   - "ws_close": closes the matching upstream websocket
//
// Malformed JSON and unknown types are ignored. Before returning, Serve closes
// the bridge and every upstream websocket, then cancels all work it started.
//
// NOTE(review): ws_open dials asynchronously and sends no acknowledgement, so
// ws_data/ws_close frames arriving before the dial completes find no upstream
// entry and are dropped — confirm the coordinator accounts for this.
func (b *codeBridge) Serve(ctx context.Context) error {
	// Scope every per-frame goroutine to this bridge session. Without this,
	// in-flight HTTP requests, pending upstream dials and upstream readers keep
	// running on the caller's ctx after the bridge drops, leaking across
	// reconnects (a dial finishing after Close would register an orphaned conn).
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Deferred after cancel so it runs first (LIFO): upstream websockets are
	// closed via Close before their contexts are cancelled.
	defer b.Close(websocket.StatusNormalClosure, "bridge stopped")
	for {
		_, data, err := b.ws.Read(ctx)
		if err != nil {
			return err
		}
		var msg codeProxyMessage
		if err := json.Unmarshal(data, &msg); err != nil {
			// Best effort: skip frames we cannot decode rather than drop the bridge.
			continue
		}
		switch msg.Type {
		case "http":
			go b.handleHTTP(ctx, msg)
		case "ws_open":
			go b.openUpstreamWebSocket(ctx, msg)
		case "ws_data":
			b.writeUpstreamWebSocket(ctx, msg)
		case "ws_close":
			b.closeUpstreamWebSocket(msg.ID, websocket.StatusCode(msg.Code), msg.Reason)
		}
	}
}
// Close closes the coordinator websocket with code/reason and then closes every
// registered upstream websocket with a normal closure. It is safe to call on a
// nil bridge; close errors are ignored (best effort).
func (b *codeBridge) Close(code websocket.StatusCode, reason string) {
	if b == nil {
		return
	}
	if b.ws != nil {
		_ = b.ws.Close(code, reason)
	}
	// Detach the upstream set under the lock, then close outside it:
	// websocket Close performs a close handshake that can block, and the
	// reader/writer paths (closeUpstreamWebSocket, writeUpstreamWebSocket)
	// need b.mu and would stall behind it.
	b.mu.Lock()
	conns := b.upstream
	b.upstream = make(map[string]*websocket.Conn)
	b.mu.Unlock()
	for _, conn := range conns {
		_ = conn.Close(websocket.StatusNormalClosure, "bridge stopped")
	}
}
// handleHTTP replays one proxied HTTP request from the coordinator against the
// local code-server tunnel and sends the response back over the bridge.
//
// Responses whose body fits in maxCodeBridgeBodyChunkBytes are sent as a single
// "http" frame; larger ones are sent as "http_start" (status and headers), one
// or more "http_body" chunks, and a final "http_end". Failures are reported as
// an "http" frame with Status 400 (undecodable request body) or 502 and Error
// set. Only the first value of each response header is relayed.
func (b *codeBridge) handleHTTP(ctx context.Context, msg codeProxyMessage) {
	// Upper bound on how much of an upstream response is buffered in memory.
	const maxResponseBytes = 25 * 1024 * 1024

	fail := func(status int, err error) {
		_ = b.writeJSON(ctx, codeProxyMessage{Type: "http", ID: msg.ID, Status: status, Error: err.Error()})
	}

	body, err := base64.StdEncoding.DecodeString(msg.Body)
	if err != nil {
		// DecodeString returns the bytes decoded before the error; forwarding
		// them would silently send a corrupted request body.
		fail(http.StatusBadRequest, fmt.Errorf("decode request body: %w", err))
		return
	}
	upstream := b.baseURL + codeUpstreamPath(msg.Path)
	req, err := http.NewRequestWithContext(ctx, msg.Method, upstream, bytes.NewReader(body))
	if err != nil {
		fail(http.StatusBadGateway, err)
		return
	}
	for key, value := range msg.Headers {
		req.Header.Set(key, value)
	}
	resp, err := b.client.Do(req)
	if err != nil {
		fail(http.StatusBadGateway, err)
		return
	}
	defer resp.Body.Close()
	// Read one byte past the cap so an oversized response is reported instead
	// of being silently truncated (which would also contradict Content-Length).
	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
	if err != nil {
		fail(http.StatusBadGateway, err)
		return
	}
	if len(respBody) > maxResponseBytes {
		fail(http.StatusBadGateway, fmt.Errorf("upstream response exceeds %d bytes", maxResponseBytes))
		return
	}
	if isCodeHTML(resp.Header.Get("content-type")) {
		respBody = rewriteCodeHTML(respBody)
		// The rewrite can shrink the body, so the upstream length is stale.
		resp.Header.Del("Content-Length")
	}
	headers := make(map[string]string, len(resp.Header))
	for key, values := range resp.Header {
		if len(values) > 0 {
			headers[key] = values[0]
		}
	}
	message := codeProxyMessage{
		Type:    "http",
		ID:      msg.ID,
		Status:  resp.StatusCode,
		Headers: headers,
	}
	if len(respBody) <= maxCodeBridgeBodyChunkBytes {
		message.Body = base64.StdEncoding.EncodeToString(respBody)
		_ = b.writeJSON(ctx, message)
		return
	}
	// Large body: announce status/headers, stream chunks, then terminate.
	message.Type = "http_start"
	if err := b.writeJSON(ctx, message); err != nil {
		return
	}
	for len(respBody) > 0 {
		n := min(len(respBody), maxCodeBridgeBodyChunkBytes)
		if err := b.writeJSON(ctx, codeProxyMessage{
			Type: "http_body",
			ID:   msg.ID,
			Body: base64.StdEncoding.EncodeToString(respBody[:n]),
		}); err != nil {
			return
		}
		respBody = respBody[n:]
	}
	_ = b.writeJSON(ctx, codeProxyMessage{Type: "http_end", ID: msg.ID})
}
// openUpstreamWebSocket dials the code-server websocket named by msg.Path on
// the local tunnel, forwarding msg.Headers on the handshake. On success the
// connection is registered under msg.ID and a reader goroutine pumps its
// messages back to the coordinator; on failure a "ws_close" with
// StatusInternalError and the dial error is sent instead.
func (b *codeBridge) openUpstreamWebSocket(ctx context.Context, msg codeProxyMessage) {
	// nhooyr.io/websocket limits incoming messages to 32 KiB by default; VS
	// Code's remote protocol frames regularly exceed that, which would close
	// the upstream connection with StatusMessageTooBig.
	const upstreamReadLimit = 64 << 20

	upstream := "ws" + strings.TrimPrefix(b.baseURL, "http") + codeUpstreamPath(msg.Path)
	header := http.Header{}
	for key, value := range msg.Headers {
		header.Set(key, value)
	}
	conn, _, err := websocket.Dial(ctx, upstream, &websocket.DialOptions{HTTPHeader: header})
	if err != nil {
		_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: msg.ID, Code: int(websocket.StatusInternalError), Reason: err.Error()})
		return
	}
	conn.SetReadLimit(upstreamReadLimit)
	b.mu.Lock()
	b.upstream[msg.ID] = conn
	b.mu.Unlock()
	go b.readUpstreamWebSocket(ctx, msg.ID, conn)
}
// readUpstreamWebSocket pumps messages from one upstream code-server websocket
// to the coordinator as base64 "ws_data" frames until a read fails. It then
// sends a "ws_close" for id and unregisters/closes the upstream connection.
func (b *codeBridge) readUpstreamWebSocket(ctx context.Context, id string, conn *websocket.Conn) {
	for {
		_, data, err := conn.Read(ctx)
		if err != nil {
			// A read failure without a close frame (network error, EOF,
			// oversized message, cancellation) is not a clean shutdown; report
			// it like a failed ws_open instead of as a normal 1000 closure.
			code := int(websocket.StatusInternalError)
			reason := err.Error()
			var closeErr websocket.CloseError
			if errors.As(err, &closeErr) {
				// A close frame was exchanged: relay its status and reason.
				code = int(closeErr.Code)
				reason = closeErr.Reason
			}
			_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_close", ID: id, Code: code, Reason: reason})
			b.closeUpstreamWebSocket(id, websocket.StatusNormalClosure, "closed")
			return
		}
		_ = b.writeJSON(ctx, codeProxyMessage{Type: "ws_data", ID: id, Body: base64.StdEncoding.EncodeToString(data)})
	}
}
// writeUpstreamWebSocket forwards the base64-decoded payload of a "ws_data"
// frame to the upstream websocket registered under msg.ID as a binary message.
// Frames for unknown IDs are dropped, and write errors are ignored (best effort).
func (b *codeBridge) writeUpstreamWebSocket(ctx context.Context, msg codeProxyMessage) {
	payload, _ := base64.StdEncoding.DecodeString(msg.Body)

	b.mu.Lock()
	target, registered := b.upstream[msg.ID]
	b.mu.Unlock()

	if !registered || target == nil {
		return
	}
	_ = target.Write(ctx, websocket.MessageBinary, payload)
}
// closeUpstreamWebSocket unregisters the upstream websocket for id and closes
// it with code/reason, substituting StatusNormalClosure for a zero code. It is
// a no-op when no connection is registered under id.
func (b *codeBridge) closeUpstreamWebSocket(id string, code websocket.StatusCode, reason string) {
	b.mu.Lock()
	conn, found := b.upstream[id]
	delete(b.upstream, id)
	b.mu.Unlock()

	if !found || conn == nil {
		return
	}
	if code == 0 {
		code = websocket.StatusNormalClosure
	}
	_ = conn.Close(code, reason)
}
// writeJSON marshals msg and sends it to the coordinator as a single text
// frame. writeMu serializes senders, since several goroutines share b.ws.
func (b *codeBridge) writeJSON(ctx context.Context, msg codeProxyMessage) error {
	frame, marshalErr := json.Marshal(msg)
	if marshalErr != nil {
		return marshalErr
	}
	b.writeMu.Lock()
	defer b.writeMu.Unlock()
	return b.ws.Write(ctx, websocket.MessageText, frame)
}
// isRetryableCodeBridgeError reports whether a bridge Serve error looks like a
// transient disconnect worth reconnecting after: a close frame with a
// server-side/temporary status (internal error, service restart, going away,
// try again later) or a connection that dropped without a close frame.
func isRetryableCodeBridgeError(err error) bool {
	if err == nil {
		return false
	}
	var closeErr websocket.CloseError
	if errors.As(err, &closeErr) {
		switch closeErr.Code {
		case websocket.StatusInternalError,
			websocket.StatusServiceRestart,
			websocket.StatusGoingAway,
			websocket.StatusTryAgainLater:
			return true
		}
		return false
	}
	// Connection dropped mid-stream; prefer the wrapped sentinel over message text.
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return true
	}
	// Fallback for errors that are not wrapped with %w.
	text := err.Error()
	return strings.Contains(text, "failed to read frame header: EOF") ||
		strings.Contains(text, "tls: bad record MAC")
}
// codeUpstreamPath maps a portal request path to the path requested from the
// local code-server. A "/portal/leases/<id>/code[/rest]" prefix is stripped to
// "/rest" (or "/"); other paths pass through. The query string is preserved,
// percent-escapes in the remaining path (e.g. %2F) are kept intact, and the
// result always starts with "/" so it can be appended safely to the tunnel's
// base URL. Unparseable input yields "/".
func codeUpstreamPath(path string) string {
	u, err := url.Parse(path)
	if err != nil {
		return "/"
	}
	// Work on the escaped form so encoded separators are not decoded and then
	// re-split as real path segments.
	parts := strings.Split(strings.TrimPrefix(u.EscapedPath(), "/"), "/")
	if len(parts) >= 4 && parts[0] == "portal" && parts[1] == "leases" && parts[3] == "code" {
		tail := "/" + strings.Join(parts[4:], "/")
		decoded, err := url.PathUnescape(tail)
		if err != nil {
			return "/"
		}
		u.Path, u.RawPath = decoded, tail
	}
	uri := u.RequestURI()
	// A relative path such as "@host/x" would otherwise turn
	// "http://127.0.0.1:port" + path into a URL with a different host.
	if !strings.HasPrefix(uri, "/") {
		uri = "/" + uri
	}
	return uri
}
// isCodeHTML reports whether a Content-Type header value starts with the
// "text/html" media type, compared case-insensitively (parameters such as
// "; charset=utf-8" are allowed after it).
func isCodeHTML(contentType string) bool {
	_, isHTML := strings.CutPrefix(strings.ToLower(contentType), "text/html")
	return isHTML
}
// rewriteCodeHTML returns a copy of an HTML page with every empty module
// script tag (`<script type="module" src=""></script>`) removed.
func rewriteCodeHTML(body []byte) []byte {
	emptyModuleScript := []byte(`<script type="module" src=""></script>`)
	return bytes.ReplaceAll(body, emptyModuleScript, nil)
}
// availableLocalCodePort picks a local port for the code-server tunnel. It
// returns the first port in 8081-8180 that can currently be bound on
// 127.0.0.1; if all are busy it asks the kernel for a free ephemeral port, and
// only as a last resort returns "8081". The port is released before returning,
// so another process may still grab it first (inherent race).
func availableLocalCodePort() string {
	for port := 8081; port <= 8180; port++ {
		ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
		if err != nil {
			continue
		}
		_ = ln.Close()
		return fmt.Sprint(port)
	}
	// Every preferred port is taken; returning one of them would guarantee a
	// bind failure, so let the kernel choose a free port instead.
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		defer ln.Close()
		if addr, ok := ln.Addr().(*net.TCPAddr); ok {
			return fmt.Sprint(addr.Port)
		}
	}
	return "8081"
}
// webCodeAgentURL builds the websocket URL of the coordinator's code bridge
// endpoint for leaseID: base's scheme becomes wss (https) or ws (otherwise),
// "/v1/leases/<leaseID>/code/agent" is appended to base's path, the query is
// replaced by ticket=<ticket>, and any fragment is dropped. If base cannot be
// parsed it is returned unchanged.
func webCodeAgentURL(base, leaseID, ticket string) string {
	u, err := url.Parse(base)
	if err != nil {
		return base
	}
	if u.Scheme == "https" {
		u.Scheme = "wss"
	} else {
		u.Scheme = "ws"
	}
	// Build the path in escaped form and set Path and RawPath together.
	// Assigning the already-escaped lease ID to u.Path alone made String()
	// escape it a second time ("a b" -> "a%2520b").
	rawPath := strings.TrimRight(u.EscapedPath(), "/") + "/v1/leases/" + url.PathEscape(leaseID) + "/code/agent"
	decoded, err := url.PathUnescape(rawPath)
	if err != nil {
		return base
	}
	u.Path, u.RawPath = decoded, rawPath
	values := url.Values{}
	values.Set("ticket", ticket)
	u.RawQuery = values.Encode()
	u.Fragment = ""
	return u.String()
}
// webCodePortalURL returns the browser URL of the portal Code page for
// leaseID: "/portal/leases/<leaseID>/code/" appended to base's path, with any
// query and fragment removed. If base cannot be parsed it is returned unchanged.
func webCodePortalURL(base, leaseID string) string {
	u, err := url.Parse(base)
	if err != nil {
		return base
	}
	// Set Path and RawPath together so the escaped lease ID is not escaped
	// again by String().
	rawPath := strings.TrimRight(u.EscapedPath(), "/") + "/portal/leases/" + url.PathEscape(leaseID) + "/code/"
	decoded, err := url.PathUnescape(rawPath)
	if err != nil {
		return base
	}
	u.Path, u.RawPath = decoded, rawPath
	u.RawQuery = ""
	u.Fragment = ""
	return u.String()
}
// CreateCodeTicket asks the coordinator (POST /v1/leases/<leaseID>/code/ticket)
// for a ticket authorizing a code bridge connection for leaseID. A successful
// response without a ticket is rejected here, since the bridge dial cannot
// succeed without one and would otherwise fail with an opaque handshake error.
func (c *CoordinatorClient) CreateCodeTicket(ctx context.Context, leaseID string) (coordinatorCodeTicket, error) {
	var res coordinatorCodeTicket
	if err := c.do(ctx, http.MethodPost, "/v1/leases/"+url.PathEscape(leaseID)+"/code/ticket", map[string]any{}, &res); err != nil {
		return res, err
	}
	if res.Ticket == "" {
		return res, fmt.Errorf("coordinator returned an empty code ticket for lease %q", leaseID)
	}
	return res, nil
}