From fa9cc0e6bc445490f443fc69804bf1eb486124c9 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Wed, 6 May 2026 15:00:49 -0700 Subject: [PATCH] feat(aws): route capacity across regions --- internal/cli/aws.go | 65 ++++++++++++++++++++++++++++- internal/cli/ssh_test.go | 19 +++++++++ worker/src/aws.ts | 64 ++++++++++++++++++++++++++--- worker/src/fleet.ts | 89 +++++++++++++++++++++++++++++++--------- worker/src/types.ts | 3 ++ worker/test/aws.test.ts | 19 +++++++++ worker/wrangler.jsonc | 1 + 7 files changed, 234 insertions(+), 26 deletions(-) diff --git a/internal/cli/aws.go b/internal/cli/aws.go index 5b217f8..4a9028b 100644 --- a/internal/cli/aws.go +++ b/internal/cli/aws.go @@ -26,11 +26,15 @@ func newAWSClient(ctx context.Context, cfg Config) (*AWSClient, error) { if cfg.AWSRegion == "" { return nil, exit(3, "CRABBOX_AWS_REGION or AWS_REGION is required") } - awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(cfg.AWSRegion)) + return newAWSClientForRegion(ctx, cfg, cfg.AWSRegion) +} + +func newAWSClientForRegion(ctx context.Context, cfg Config, region string) (*AWSClient, error) { + awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) if err != nil { return nil, err } - return &AWSClient{ec2: ec2.NewFromConfig(awsCfg), region: cfg.AWSRegion}, nil + return &AWSClient{ec2: ec2.NewFromConfig(awsCfg), region: region}, nil } func NewAWSClient(ctx context.Context, cfg Config) (*AWSClient, error) { @@ -128,6 +132,39 @@ func (c *AWSClient) DeleteSSHKey(ctx context.Context, name string) error { } func (c *AWSClient) CreateServerWithFallback(ctx context.Context, cfg Config, publicKey, leaseID, slug string, keep bool, logf func(string, ...any)) (Server, Config, error) { + regions := awsRegionCandidates(cfg, c.region) + if len(regions) > 1 { + var errs []error + for _, region := range regions { + next := cfg + next.AWSRegion = region + client := c + if region != c.region { + var err error + client, err = newAWSClientForRegion(ctx, next, region) + if err != nil { + errs = append(errs, fmt.Errorf("%s: %w", region, err)) + continue + } + } + if logf != nil && region != c.region { + logf("fallback provisioning region=%s after capacity/quota rejection\n", region) + } + server, resolved, err := client.createServerWithFallbackInRegion(ctx, next, publicKey, leaseID, slug, keep, logf) + if err == nil { + return server, resolved, nil + } + errs = append(errs, fmt.Errorf("%s: %w", region, err)) + if !isRetryableAWSRegionProvisioningError(err) { + return Server{}, resolved, joinErrors(errs) + } + } + return Server{}, cfg, joinErrors(errs) + } + return c.createServerWithFallbackInRegion(ctx, cfg, publicKey, leaseID, slug, keep, logf) +} + +func (c *AWSClient) createServerWithFallbackInRegion(ctx context.Context, cfg Config, publicKey, leaseID, slug string, keep bool, logf func(string, ...any)) (Server, Config, error) { if cfg.ProviderKey == "" { cfg.ProviderKey = "crabbox-steipete" } @@ -254,6 +291,10 @@ func (c *AWSClient) createServer(ctx context.Context, cfg Config, publicKey, lea applyAWSRunInstanceTargetOptions(input, cfg) if cfg.TargetOS == targetMacOS { input.Placement = &types.Placement{HostId: aws.String(cfg.AWSMacHostID), Tenancy: types.TenancyHost} + } else if cfg.AWSSubnetID == "" { + if zone := awsAvailabilityZoneForRegion(cfg, cfg.AWSRegion); zone != "" { + input.Placement = &types.Placement{AvailabilityZone: aws.String(zone)} + } } out, err := c.ec2.RunInstances(ctx, input) if err != nil { @@ -540,6 +581,26 @@ func isRetryableAWSProvisioningError(err error) bool { strings.Contains(s, "instance type"))) } +func isRetryableAWSRegionProvisioningError(err error) bool { + s := err.Error() + return isRetryableAWSProvisioningError(err) || + strings.Contains(s, "quota ") || + strings.Contains(s, "capacity") +} + +func awsRegionCandidates(cfg Config, preferredRegion string) []string { + return appendUniqueStrings([]string{preferredRegion, cfg.AWSRegion}, cfg.Capacity.Regions...) +} + +func awsAvailabilityZoneForRegion(cfg Config, region string) string { + for _, zone := range cfg.Capacity.AvailabilityZones { + if strings.HasPrefix(zone, region) { + return zone + } + } + return "" +} + func awsLaunchCandidates(cfg Config) []string { if cfg.ServerTypeExplicit { return []string{cfg.ServerType} diff --git a/internal/cli/ssh_test.go b/internal/cli/ssh_test.go index 6fbf0f6..6432439 100644 --- a/internal/cli/ssh_test.go +++ b/internal/cli/ssh_test.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "path/filepath" + "reflect" "strings" "testing" "time" @@ -725,6 +726,24 @@ func TestAWSLaunchCandidatesAddsPolicyFallbackUnlessExact(t *testing.T) { } } +func TestAWSRegionAndAvailabilityZoneCandidates(t *testing.T) { + cfg := Config{ + AWSRegion: "eu-west-1", + Capacity: CapacityConfig{ + Regions: []string{"us-east-1", "eu-west-1"}, + AvailabilityZones: []string{"us-east-1a", "eu-west-1b"}, + }, + } + got := awsRegionCandidates(cfg, "eu-west-2") + want := []string{"eu-west-2", "eu-west-1", "us-east-1"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("awsRegionCandidates=%v want %v", got, want) + } + if zone := awsAvailabilityZoneForRegion(cfg, "eu-west-1"); zone != "eu-west-1b" { + t.Fatalf("awsAvailabilityZoneForRegion=%q want eu-west-1b", zone) + } +} + func TestRemoteSyncSanityReportsDeletionSample(t *testing.T) { got := remoteSyncSanity("/work/repo", false) for _, want := range []string{ diff --git a/worker/src/aws.ts b/worker/src/aws.ts index 17a065e..f2d6c89 100644 --- a/worker/src/aws.ts +++ b/worker/src/aws.ts @@ -76,7 +76,9 @@ export class EC2SpotClient { "Filter.2.Value.4": "stopped", }); return reservations(root).flatMap((reservation) => - items(record(record(reservation)["instancesSet"])["item"]).map(instanceToMachine), + items(record(record(reservation)["instancesSet"])["item"]).map((instance) => + this.withRegion(instanceToMachine(instance)), + ), ); } @@ -195,7 +197,7 @@ export class EC2SpotClient { }); for (const reservation of reservations(root)) { for (const instance of items(record(record(reservation)["instancesSet"])["item"])) { - return instanceToMachine(instance); + return this.withRegion(instanceToMachine(instance)); } } throw new Error(`aws instance not found: ${instanceID}`); @@ -359,6 +361,11 @@ export class EC2SpotClient { } params["Placement.HostId"] = hostID; params["Placement.Tenancy"] = "host"; + } else if (!subnetID) { + const availabilityZone = awsAvailabilityZoneForRegion(config, this.env, this.region); + if (availabilityZone) { + params["Placement.AvailabilityZone"] = availabilityZone; + } } addRunInstancesTagSpecifications(params, { ...labels, Name: name }, config.capacityMarket); const root = await this.ec2("RunInstances", params); @@ -366,7 +373,7 @@ export class EC2SpotClient { if (!instance) { throw new Error("aws returned no instances"); } - return instanceToMachine(instance); + return this.withRegion(instanceToMachine(instance)); } private async resolveAMI(config: LeaseConfig): Promise { @@ -574,6 +581,10 @@ export class EC2SpotClient { return undefined; } } + + private withRegion(server: ProviderMachine): ProviderMachine { + return { ...server, region: this.region }; + } } function awsSSHCIDRs(config: LeaseConfig, env: Env): string[] { @@ -699,6 +710,33 @@ export function awsLaunchCandidates( ]); } +export function awsRegionCandidates( + config: Pick, + env: Pick, + preferredRegion = "eu-west-1", +): string[] { + return uniqueStrings([ + preferredRegion, + config.awsRegion, + env.CRABBOX_AWS_REGION ?? "", + ...splitCommaList(env.CRABBOX_CAPACITY_REGIONS ?? ""), + ...config.capacityRegions, + ]); +} + +export function awsAvailabilityZoneForRegion( + config: Pick, + env: Pick, + region: string, +): string { + return ( + uniqueStrings([ + ...config.capacityAvailabilityZones, + ...splitCommaList(env.CRABBOX_CAPACITY_AVAILABILITY_ZONES ?? ""), + ]).find((zone) => zone.startsWith(region)) ?? "" + ); +} + export function applyAWSRunInstanceTargetOptions( params: Record, config: Pick, @@ -746,7 +784,23 @@ export function awsQuotaPreflightAttempt( } function uniqueStrings(values: string[]): string[] { - return [...new Set(values.filter(Boolean))]; + const out: string[] = []; + const seen = new Set(); + for (const value of values) { + const normalized = value.trim(); + if (normalized && !seen.has(normalized)) { + seen.add(normalized); + out.push(normalized); + } + } + return out; +} + +function splitCommaList(value: string): string[] { + return value + .split(",") + .map((item) => item.trim()) + .filter(Boolean); } function positiveInt(value: string | undefined): number { @@ -789,7 +843,7 @@ export function awsProvisioningErrorCategory(message: string): string { return ""; } -function isRetryableAWSProvisioningError(message: string): boolean { +export function isRetryableAWSProvisioningError(message: string): boolean { return awsProvisioningErrorCategory(message) !== ""; } diff --git a/worker/src/fleet.ts b/worker/src/fleet.ts index 222f5be..c6e7d8c 100644 --- a/worker/src/fleet.ts +++ b/worker/src/fleet.ts @@ -1,5 +1,10 @@ import { isAdminRequest } from "./auth"; -import { EC2SpotClient } from "./aws"; +import { + EC2SpotClient, + awsProvisioningErrorCategory, + awsRegionCandidates, + isRetryableAWSProvisioningError, +} from "./aws"; import { leaseConfig, validCIDRs } from "./config"; import { HetznerClient } from "./hetzner"; import { errorMessage, json, pathParts, readJson, requestOwner } from "./http"; @@ -516,6 +521,9 @@ export class FleetDurableObject implements DurableObject { ); record.cloudID = server.cloudID; record.serverType = serverType; + if (config.provider === "aws" && server.region) { + config.awsRegion = server.region; + } if (attempts && attempts.length > 0) { record.provisioningAttempts = attempts; } @@ -535,7 +543,7 @@ export class FleetDurableObject implements DurableObject { record.estimatedHourlyUSD = finalCost.hourlyUSD; record.maxEstimatedUSD = finalCost.maxUSD; if (config.provider === "aws") { - record.region = config.awsRegion; + record.region = server.region ?? config.awsRegion; } await this.putLease(record); await this.scheduleAlarm(); @@ -3124,8 +3132,13 @@ class HetznerProvider implements CloudProvider { class AWSProvider implements CloudProvider { private readonly client: EC2SpotClient; + private readonly region: string; - constructor(env: Env, region: string) { + constructor( + private readonly env: Env, + region: string, + ) { + this.region = region; this.client = new EC2SpotClient(env, region); } @@ -3139,21 +3152,46 @@ class AWSProvider implements CloudProvider { slug: string, owner: string, ): Promise<{ server: ProviderMachine; serverType: string; attempts?: ProvisioningAttempt[] }> { - const { server, serverType, attempts } = await this.client.createServerWithFallback( - config, - leaseID, - slug, - owner, - ); - const result: { - server: ProviderMachine; - serverType: string; - attempts?: ProvisioningAttempt[]; - } = { server: await this.client.waitForServerIP(server.cloudID), serverType }; - if (attempts && attempts.length > 0) { - result.attempts = attempts; + const regions = awsRegionCandidates(config, this.env, this.region); + const failures: string[] = []; + const regionAttempts: ProvisioningAttempt[] = []; + for (const region of regions) { + const client = region === this.region ? this.client : new EC2SpotClient(this.env, region); + try { + // oxlint-disable-next-line eslint/no-await-in-loop -- region fallback must preserve ordered capacity preference. + const { server, serverType, attempts } = await client.createServerWithFallback( + { ...config, awsRegion: region }, + leaseID, + slug, + owner, + ); + // oxlint-disable-next-line eslint/no-await-in-loop -- wait on the region that created the instance. + const readyServer = await client.waitForServerIP(server.cloudID); + const result: { + server: ProviderMachine; + serverType: string; + attempts?: ProvisioningAttempt[]; + } = { server: { ...readyServer, region }, serverType }; + const allAttempts = [...regionAttempts, ...(attempts ?? [])]; + if (allAttempts.length > 0) { + result.attempts = allAttempts; + } + return result; + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + regionAttempts.push({ + serverType: config.serverType, + market: config.capacityMarket, + category: awsProvisioningErrorCategory(message) || "region", + message: `region ${region}: ${message}`, + }); + failures.push(`${region}: ${message}`); + if (!isRetryableAWSRegionProvisioningError(message)) { + break; + } + } } - return result; + throw new Error(failures.join("; ")); } async deleteServer(id: string): Promise { @@ -3172,7 +3210,20 @@ class AWSProvider implements CloudProvider { await this.client.deleteSSHKey(name); } - hourlyPriceUSD(serverType: string): Promise { - return this.client.hourlySpotPriceUSD(serverType); + hourlyPriceUSD( + serverType: string, + config: ReturnType, + ): Promise { + const region = config.awsRegion || this.region; + const client = region === this.region ? this.client : new EC2SpotClient(this.env, region); + return client.hourlySpotPriceUSD(serverType); } } + +function isRetryableAWSRegionProvisioningError(message: string): boolean { + return ( + isRetryableAWSProvisioningError(message) || + message.includes("quota ") || + message.includes("capacity") + ); +} diff --git a/worker/src/types.ts b/worker/src/types.ts index b59f61b..e4e7dbd 100644 --- a/worker/src/types.ts +++ b/worker/src/types.ts @@ -12,6 +12,8 @@ export interface Env { CRABBOX_AWS_ROOT_GB?: string; CRABBOX_AWS_SSH_CIDRS?: string; CRABBOX_AWS_MAC_HOST_ID?: string; + CRABBOX_CAPACITY_REGIONS?: string; + CRABBOX_CAPACITY_AVAILABILITY_ZONES?: string; CRABBOX_SHARED_TOKEN?: string; CRABBOX_ADMIN_TOKEN?: string; CRABBOX_SESSION_SECRET?: string; @@ -380,6 +382,7 @@ export interface ProviderMachine { provider: Provider; id: number; cloudID: string; + region?: string; name: string; status: string; serverType: string; diff --git a/worker/test/aws.test.ts b/worker/test/aws.test.ts index cf1229d..ef9d5e3 100644 --- a/worker/test/aws.test.ts +++ b/worker/test/aws.test.ts @@ -3,11 +3,13 @@ import { describe, expect, it } from "vitest"; import { addRunInstancesTagSpecifications, applyAWSRunInstanceTargetOptions, + awsAvailabilityZoneForRegion, awsInstanceTypeVCPUs, awsLaunchCandidates, awsProvisioningErrorCategory, awsQuotaCodeForMarket, awsQuotaPreflightAttempt, + awsRegionCandidates, createSecurityGroupParams, } from "../src/aws"; @@ -97,6 +99,23 @@ describe("aws provider", () => { ).not.toContain("t3.large"); }); + it("builds ordered AWS region and availability-zone candidates", () => { + expect( + awsRegionCandidates( + { awsRegion: "eu-west-1", capacityRegions: ["us-east-1", "eu-west-1"] }, + { CRABBOX_AWS_REGION: "eu-central-1", CRABBOX_CAPACITY_REGIONS: "us-west-2, us-east-1" }, + "eu-west-2", + ), + ).toEqual(["eu-west-2", "eu-west-1", "eu-central-1", "us-west-2", "us-east-1"]); + expect( + awsAvailabilityZoneForRegion( + { capacityAvailabilityZones: ["us-east-1a", "eu-west-1b"] }, + { CRABBOX_CAPACITY_AVAILABILITY_ZONES: "eu-west-2a,eu-west-1c" }, + "eu-west-1", + ), + ).toBe("eu-west-1b"); + }); + it("maps AWS instance types to vCPU quota units", () => { expect(awsInstanceTypeVCPUs("c7a.48xlarge")).toBe(192); expect(awsInstanceTypeVCPUs("c7a.xlarge")).toBe(4); diff --git a/worker/wrangler.jsonc b/worker/wrangler.jsonc index 0146e7f..2b5e052 100644 --- a/worker/wrangler.jsonc +++ b/worker/wrangler.jsonc @@ -15,6 +15,7 @@ "CRABBOX_ACCESS_TEAM_DOMAIN": "crabbox-openclaw.cloudflareaccess.com", "CRABBOX_ACCESS_AUD": "2c79b4c28dd40029b75b1e8d36d9a945ddc864dd40a34e50f6538bae8a3633ea", "CRABBOX_AWS_REGION": "eu-west-1", + "CRABBOX_CAPACITY_REGIONS": "eu-west-1,eu-west-2,eu-central-1,us-east-1,us-west-2", }, "routes": [ {