diff --git a/internal/cli/auth.go b/internal/cli/auth.go index 9ddc968..f1b727e 100644 --- a/internal/cli/auth.go +++ b/internal/cli/auth.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "io" + "net/url" "os" "os/exec" "runtime" @@ -94,6 +95,17 @@ func (a App) loginWithGitHub(ctx context.Context, brokerURL, provider string, no if err != nil { return err } + if canonicalBrokerURL, ok := canonicalBrokerURLFromLoginURL(start.URL); ok && !sameBrokerURL(brokerURL, canonicalBrokerURL) { + brokerURL = canonicalBrokerURL + client, err = coordinatorClientForLogin(brokerURL) + if err != nil { + return err + } + start, err = client.StartGitHubLogin(ctx, pollSecretHash, provider) + if err != nil { + return err + } + } if noBrowser { fmt.Fprintf(a.Stderr, "open this GitHub login URL:\n%s\n", start.URL) } else if err := openBrowser(start.URL); err != nil { @@ -183,6 +195,46 @@ func coordinatorClientForLogin(brokerURL string) (*CoordinatorClient, error) { return coord, nil } +func canonicalBrokerURLFromLoginURL(loginURL string) (string, bool) { + u, err := url.Parse(loginURL) + if err != nil { + return "", false + } + redirect := u.Query().Get("redirect_uri") + if redirect == "" { + return "", false + } + redirectURL, err := url.Parse(redirect) + if err != nil || redirectURL.Scheme == "" || redirectURL.Host == "" { + return "", false + } + const callbackPath = "/v1/auth/github/callback" + cleanPath := strings.TrimRight(redirectURL.Path, "/") + if !strings.HasSuffix(cleanPath, callbackPath) { + return "", false + } + redirectURL.Path = strings.TrimRight(strings.TrimSuffix(cleanPath, callbackPath), "/") + redirectURL.RawPath = "" + redirectURL.RawQuery = "" + redirectURL.Fragment = "" + return strings.TrimRight(redirectURL.String(), "/"), true +} + +func sameBrokerURL(left, right string) bool { + return normalizedBrokerURL(left) == normalizedBrokerURL(right) +} + +func normalizedBrokerURL(value string) string { + u, err := url.Parse(value) + if err != nil { + return strings.TrimRight(value, "/") + } + u.Path = strings.TrimRight(u.Path, "/") + u.RawQuery = "" + u.Fragment = "" + return strings.TrimRight(u.String(), "/") +} + func openBrowser(target string) error { switch runtime.GOOS { case "darwin": diff --git a/internal/cli/auth_test.go b/internal/cli/auth_test.go index 2ea2d2e..a7e32d8 100644 --- a/internal/cli/auth_test.go +++ b/internal/cli/auth_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "path/filepath" "strings" "testing" @@ -164,3 +165,126 @@ func TestGitHubLoginNoBrowserStoresReturnedToken(t *testing.T) { t.Fatalf("unexpected config: %#v", cfg) } } + +func TestGitHubLoginMigratesToCanonicalRedirectOrigin(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(home, ".config")) + t.Setenv("CRABBOX_CONFIG", "") + t.Setenv("CRABBOX_COORDINATOR", "") + t.Setenv("CRABBOX_COORDINATOR_TOKEN", "") + t.Setenv("CRABBOX_PROVIDER", "") + + var seenPollSecretHash string + var canonicalStartCount int + var canonical *httptest.Server + canonical = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/v1/auth/github/start": + canonicalStartCount++ + var body struct { + PollSecretHash string `json:"pollSecretHash"` + Provider string `json:"provider"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body.Provider != "aws" { + t.Fatalf("provider=%q", body.Provider) + } + seenPollSecretHash = body.PollSecretHash + _ = json.NewEncoder(w).Encode(CoordinatorGitHubLoginStart{ + LoginID: "login_canonical", + URL: githubAuthorizeURLForTest(canonical.URL), + ExpiresAt: time.Now().Add(time.Minute).Format(time.RFC3339), + }) + case "/v1/auth/github/poll": + var body struct { + LoginID string `json:"loginID"` + PollSecret string `json:"pollSecret"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body.LoginID != "login_canonical" { + t.Fatalf("loginID=%q", body.LoginID) + } + if sha256Hex(body.PollSecret) != seenPollSecretHash { + t.Fatal("poll secret did not match canonical start hash") + } + _ = json.NewEncoder(w).Encode(CoordinatorGitHubLoginPoll{ + Status: "complete", + Token: "canonical-session-token", + Owner: "friend@example.com", + Org: "openclaw", + Login: "friend", + Provider: "aws", + }) + case "/v1/whoami": + if got := r.Header.Get("Authorization"); got != "Bearer canonical-session-token" { + t.Fatalf("authorization=%q", got) + } + _ = json.NewEncoder(w).Encode(CoordinatorWhoami{ + Owner: "friend@example.com", + Org: "openclaw", + Auth: "github", + }) + default: + http.NotFound(w, r) + } + })) + defer canonical.Close() + + var staleStartCount int + stale := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/v1/auth/github/start": + staleStartCount++ + _ = json.NewEncoder(w).Encode(CoordinatorGitHubLoginStart{ + LoginID: "login_stale", + URL: githubAuthorizeURLForTest(canonical.URL), + ExpiresAt: time.Now().Add(time.Minute).Format(time.RFC3339), + }) + case "/v1/auth/github/poll": + t.Fatal("poll should restart against canonical redirect origin") + default: + http.NotFound(w, r) + } + })) + defer stale.Close() + + var stdout, stderr bytes.Buffer + app := App{Stdout: &stdout, Stderr: &stderr} + if err := app.login(context.Background(), []string{"--url", stale.URL, "--provider", "aws", "--no-browser"}); err != nil { + t.Fatal(err) + } + if staleStartCount != 1 || canonicalStartCount != 1 { + t.Fatalf("start counts stale=%d canonical=%d", staleStartCount, canonicalStartCount) + } + if !strings.Contains(stderr.String(), "redirect_uri="+url.QueryEscape(canonical.URL+"/v1/auth/github/callback")) { + t.Fatalf("stderr=%q", stderr.String()) + } + if !strings.Contains(stdout.String(), "user=friend@example.com") { + t.Fatalf("stdout=%q", stdout.String()) + } + cfg, err := loadConfig() + if err != nil { + t.Fatal(err) + } + if cfg.Coordinator != canonical.URL || cfg.CoordToken != "canonical-session-token" || cfg.Provider != "aws" { + t.Fatalf("unexpected config: %#v", cfg) + } +} + +func TestCanonicalBrokerURLFromLoginURL(t *testing.T) { + got, ok := canonicalBrokerURLFromLoginURL("https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fcrabbox.openclaw.ai%2Fv1%2Fauth%2Fgithub%2Fcallback&state=x") + if !ok || got != "https://crabbox.openclaw.ai" { + t.Fatalf("canonical=%q ok=%v", got, ok) + } +} + +func githubAuthorizeURLForTest(base string) string { + return "https://github.com/login/oauth/authorize?redirect_uri=" + url.QueryEscape(base+"/v1/auth/github/callback") + "&state=test" +} diff --git a/worker/src/index.ts b/worker/src/index.ts index 37b9c7a..3eed0be 100644 --- a/worker/src/index.ts +++ b/worker/src/index.ts @@ -14,6 +14,10 @@ export default { if (request.method === "GET" && url.pathname === "/") { return new Response(null, { status: 302, headers: { location: "/portal" } }); } + const canonicalPortal = canonicalPortalRedirect(request, env, url); + if (canonicalPortal) { + return canonicalPortal; + } if (url.pathname.startsWith("/v1/auth/")) { const id = env.FLEET.idFromName("default"); return env.FLEET.get(id).fetch(request); @@ -68,6 +72,28 @@ function isWebVNCAgentUpgrade(request: Request, url: URL): boolean { ); } +function canonicalPortalRedirect(request: Request, env: Env, url: URL): Response | undefined { + if ( + request.method !== "GET" || + request.headers.get("upgrade")?.toLowerCase() === "websocket" || + !url.pathname.startsWith("/portal") || + !env.CRABBOX_PUBLIC_URL + ) { + return undefined; + } + let publicURL: URL; + try { + publicURL = new URL(env.CRABBOX_PUBLIC_URL); + } catch { + return undefined; + } + if (url.origin === publicURL.origin) { + return undefined; + } + const location = new URL(`${url.pathname}${url.search}`, publicURL.origin); + return new Response(null, { status: 302, headers: { location: location.toString() } }); +} + function requestWithPortalCookie(request: Request): Request { if (request.headers.get("authorization")) { return request; diff --git a/worker/test/http.test.ts b/worker/test/http.test.ts index 8545123..8a86be4 100644 --- a/worker/test/http.test.ts +++ b/worker/test/http.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it, vi } from "vitest"; -import { isAuthorized } from "../src"; +import coordinator, { isAuthorized } from "../src"; import { authenticateRequest, base64URL, @@ -8,6 +8,7 @@ import { requestWithAuthContext, } from "../src/auth"; import { requestOwner } from "../src/http"; +import type { Env } from "../src/types"; describe("coordinator auth", () => { it("denies requests when no shared token is configured", async () => { @@ -143,6 +144,39 @@ describe("coordinator auth", () => { expect(next.headers.get("cf-access-jwt-assertion")).toBeNull(); expect(requestOwner(next)).toBe("friend@example.com"); }); + + it("redirects browser portal auth routes to the configured public origin", async () => { + let fleetCalled = false; + const env = { + CRABBOX_PUBLIC_URL: "https://crabbox.openclaw.ai", + FLEET: { + idFromName: () => "default", + get: () => { + fleetCalled = true; + return { fetch: () => new Response("unexpected", { status: 599 }) }; + }, + }, + } as unknown as Env; + + const login = await coordinator.fetch( + new Request( + "https://crabbox-coordinator.steipete.workers.dev/portal/login?returnTo=%2Fportal%2Fleases%2Fcbx_1%2Fvnc", + ), + env, + ); + expect(login.status).toBe(302); + expect(login.headers.get("location")).toBe( + "https://crabbox.openclaw.ai/portal/login?returnTo=%2Fportal%2Fleases%2Fcbx_1%2Fvnc", + ); + + const logout = await coordinator.fetch( + new Request("https://crabbox-coordinator.steipete.workers.dev/portal/logout"), + env, + ); + expect(logout.status).toBe(302); + expect(logout.headers.get("location")).toBe("https://crabbox.openclaw.ai/portal/logout"); + expect(fleetCalled).toBe(false); + }); }); async function accessJwt(input: {